first version
This commit is contained in:
+328
-13
@@ -2,38 +2,353 @@ package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
|
||||
"overnight-trading-bot/internal/config"
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/execution"
|
||||
"overnight-trading-bot/internal/features"
|
||||
"overnight-trading-bot/internal/healthcheck"
|
||||
"overnight-trading-bot/internal/instruments"
|
||||
"overnight-trading-bot/internal/logging"
|
||||
"overnight-trading-bot/internal/marketdata"
|
||||
"overnight-trading-bot/internal/notify"
|
||||
"overnight-trading-bot/internal/position"
|
||||
"overnight-trading-bot/internal/reconciliation"
|
||||
mysqlrepo "overnight-trading-bot/internal/repository/mysql"
|
||||
"overnight-trading-bot/internal/risk"
|
||||
"overnight-trading-bot/internal/scheduler"
|
||||
signalengine "overnight-trading-bot/internal/signal"
|
||||
"overnight-trading-bot/internal/statemachine"
|
||||
"overnight-trading-bot/internal/timeutil"
|
||||
"overnight-trading-bot/internal/tinvest"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
ConfigPath string
|
||||
Stdout io.Writer
|
||||
Stdout io.Writer
|
||||
Stderr io.Writer
|
||||
ModeOverride string
|
||||
Unhalt bool
|
||||
Reason string
|
||||
Healthcheck bool
|
||||
HealthcheckURL string
|
||||
RunOnce bool
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, opts Options) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if opts.ConfigPath == "" {
|
||||
return errors.New("config path is required")
|
||||
if opts.Healthcheck {
|
||||
target := opts.HealthcheckURL
|
||||
if target == "" {
|
||||
target = "http://127.0.0.1:3300/health"
|
||||
}
|
||||
return healthcheck.CheckEndpoint(ctx, target)
|
||||
}
|
||||
|
||||
if opts.Stdout == nil {
|
||||
opts.Stdout = io.Discard
|
||||
}
|
||||
|
||||
if _, err := os.Stat(opts.ConfigPath); err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("config file %q does not exist; copy config.example.yaml to config.yaml and fill credentials", opts.ConfigPath)
|
||||
if opts.Stderr == nil {
|
||||
opts.Stderr = io.Discard
|
||||
}
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
return fmt.Errorf("load ENV config: %w", err)
|
||||
}
|
||||
if opts.ModeOverride != "" {
|
||||
mode, err := domain.ParseMode(opts.ModeOverride)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.App.Mode = mode
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
log := logging.New(cfg.App.LogLevel, opts.Stdout)
|
||||
log.Info("overnight trading bot starting", "mode", cfg.App.Mode)
|
||||
|
||||
return fmt.Errorf("check config file %q: %w", opts.ConfigPath, err)
|
||||
if cfg.App.Mode == domain.ModeBacktest && !opts.Unhalt {
|
||||
_, _ = fmt.Fprintf(opts.Stdout, "overnight trading bot initialized in %s mode\n", cfg.App.Mode)
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Fprintf(opts.Stdout, "overnight trading bot initialized with config %q\n", opts.ConfigPath)
|
||||
return nil
|
||||
db, err := openDB(ctx, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
if cfg.DB.MigrationsAutoApply {
|
||||
if err := mysqlrepo.ApplyMigrations(ctx, db.DB); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
repo := mysqlrepo.NewRepository(db)
|
||||
if opts.Unhalt {
|
||||
if strings.TrimSpace(opts.Reason) == "" {
|
||||
return errors.New("-unhalt requires -reason")
|
||||
}
|
||||
gateway, closer, err := buildGateway(ctx, cfg, log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if closer != nil {
|
||||
defer closer()
|
||||
}
|
||||
accountIDHash := accountHash(cfg.TInvest.AccountID)
|
||||
recon := reconciliation.New(repo, gateway, cfg.TInvest.AccountID, accountIDHash).
|
||||
WithWindow(time.Duration(cfg.Risk.ReconciliationWindowHours) * time.Hour).
|
||||
WithInFlightGrace(time.Duration(cfg.Risk.ReconciliationSkewSec) * time.Second)
|
||||
diffs, err := recon.Run(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pre-unhalt reconciliation: %w", err)
|
||||
}
|
||||
if reconciliation.HasCritical(diffs) {
|
||||
return fmt.Errorf("pre-unhalt reconciliation has critical diffs: %d", len(diffs))
|
||||
}
|
||||
if err := repo.Unhalt(ctx, opts.Reason); err != nil {
|
||||
return err
|
||||
}
|
||||
_, _ = fmt.Fprintf(opts.Stdout, "system unhalted: %s\n", opts.Reason)
|
||||
return nil
|
||||
}
|
||||
|
||||
gateway, closer, err := buildGateway(ctx, cfg, log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if closer != nil {
|
||||
defer closer()
|
||||
}
|
||||
notifier, err := notify.NewTelegram(notify.TelegramConfig{
|
||||
BotToken: cfg.Telegram.BotToken,
|
||||
ChatID: cfg.Telegram.ChatID,
|
||||
NotifyInfo: cfg.Telegram.NotifyInfo,
|
||||
NotifyWarn: cfg.Telegram.NotifyWarn,
|
||||
NotifyAlert: cfg.Telegram.NotifyAlert,
|
||||
NotifyReport: cfg.Telegram.NotifyReport,
|
||||
AuditSink: repo,
|
||||
}, log)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create notifier: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = notifier.Close()
|
||||
}()
|
||||
|
||||
accountIDHash := accountHash(cfg.TInvest.AccountID)
|
||||
recon := reconciliation.New(repo, gateway, cfg.TInvest.AccountID, accountIDHash).
|
||||
WithWindow(time.Duration(cfg.Risk.ReconciliationWindowHours) * time.Hour).
|
||||
WithInFlightGrace(time.Duration(cfg.Risk.ReconciliationSkewSec) * time.Second)
|
||||
sm := statemachine.New(repo, cfg.App.Mode)
|
||||
if _, err := sm.Recover(ctx, recon); err != nil {
|
||||
log.Warn("state recovery did not resume trading", "err", err)
|
||||
}
|
||||
health := healthcheck.New(db.DB, gateway, time.Duration(cfg.Risk.MaxClockDriftSec)*time.Second)
|
||||
health.Start(cfg.App.HealthcheckAddr)
|
||||
defer func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Duration(cfg.App.ShutdownTimeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
_ = health.Shutdown(shutdownCtx)
|
||||
}()
|
||||
if err := notifier.Info(ctx, fmt.Sprintf("bot started in %s mode", cfg.App.Mode)); err != nil {
|
||||
log.Warn("notify startup failed", "err", err)
|
||||
}
|
||||
if opts.RunOnce {
|
||||
_, _ = fmt.Fprintf(opts.Stdout, "overnight trading bot initialized in %s mode\n", cfg.App.Mode)
|
||||
return nil
|
||||
}
|
||||
runCtx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
clock := timeutil.RealClock{Loc: cfg.Location}
|
||||
runtime := buildScheduler(clock, sm, cfg, repo, gateway, notifier, recon, accountIDHash, log)
|
||||
return runtime.Run(runCtx)
|
||||
}
|
||||
|
||||
func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Config, repo *mysqlrepo.Repository, gateway tinvest.Gateway, notifier notify.Notifier, recon reconciliation.Engine, accountIDHash string, log *slog.Logger) scheduler.Scheduler {
|
||||
registry := instruments.NewRegistry(repo, gateway)
|
||||
loader := marketdata.NewLoader(repo, gateway)
|
||||
pipeline := features.NewPipeline(repo, features.PipelineConfig{
|
||||
RollingShort: cfg.Strategy.RollingShort,
|
||||
RollingLong: cfg.Strategy.RollingLong,
|
||||
EWMALambda: cfg.Strategy.EWMALambda,
|
||||
RiskBufferBps: cfg.Strategy.RiskBufferBps,
|
||||
EntrySlippageBps: cfg.Backtest.EntrySlippageBps,
|
||||
ExitSlippageBps: cfg.Backtest.ExitSlippageBps,
|
||||
CommissionRoundtripBps: cfg.Backtest.CommissionRoundtripBps,
|
||||
EntryWindow: timeutil.Window{
|
||||
Start: cfg.Execution.EntryWindowStart,
|
||||
End: cfg.Execution.EntryWindowEnd,
|
||||
},
|
||||
ExitWindow: timeutil.Window{
|
||||
Start: cfg.Execution.ExitWindowStart,
|
||||
End: cfg.Execution.ExitWindowEnd,
|
||||
},
|
||||
Location: cfg.Location,
|
||||
})
|
||||
signalEngine := signalengine.New(signalengine.Config{
|
||||
MinTStat60: cfg.Strategy.MinTStat60,
|
||||
MinWinRate60: cfg.Strategy.MinWinRate60,
|
||||
MinNetEdgeBps: cfg.Strategy.MinNetEdgeBps,
|
||||
MinADVRUB: cfg.Liquidity.MinADVRUB,
|
||||
MaxSpreadBpsDefault: cfg.Liquidity.MaxSpreadBpsDefault,
|
||||
MaxSpreadBpsMoneyMarket: cfg.Liquidity.MaxSpreadBpsMoneyMarket,
|
||||
MaxSpreadBpsBondFunds: cfg.Liquidity.MaxSpreadBpsBondFunds,
|
||||
MaxSpreadBpsEquityFunds: cfg.Liquidity.MaxSpreadBpsEquityFunds,
|
||||
MaxTickBps: cfg.Liquidity.MaxTickBps,
|
||||
RequireZeroCommission: cfg.Commission.RequireZeroCommission,
|
||||
MaxPositions: cfg.Strategy.MaxPositions,
|
||||
})
|
||||
sizer := risk.NewSizer(risk.SizingConfig{
|
||||
MaxPositionPct: cfg.Risk.MaxPositionPct,
|
||||
MaxTotalExposurePct: cfg.Risk.MaxTotalExposurePct,
|
||||
MaxParticipationRate: cfg.Liquidity.MaxParticipationRate,
|
||||
CashUsageBuffer: cfg.Risk.CashUsageBuffer,
|
||||
RiskBudgetPerInstrumentPct: cfg.Risk.RiskBudgetPerInstrumentPct,
|
||||
MinOrderNotionalRUB: cfg.Risk.MinOrderNotionalRUB,
|
||||
})
|
||||
freeOrders := risk.NewFreeOrderBudget(repo)
|
||||
riskManager := risk.NewManager(repo, risk.ManagerConfig{
|
||||
MaxDailyLossPct: cfg.Risk.MaxDailyLossPct,
|
||||
MaxWeeklyLossPct: cfg.Risk.MaxWeeklyLossPct,
|
||||
MaxMonthlyDrawdownPct: cfg.Risk.MaxMonthlyDrawdownPct,
|
||||
MaxAvgSlippageBps10Trades: cfg.Risk.MaxAvgSlippageBps10Trades,
|
||||
MaxOpenPositions: cfg.Risk.MaxOpenPositions,
|
||||
MinTimeToClose: time.Duration(cfg.Execution.MinTimeToCloseSec) * time.Second,
|
||||
MaxQuoteAge: time.Duration(cfg.Execution.MaxQuoteAgeSec) * time.Second,
|
||||
})
|
||||
execEngine := execution.NewEngine(cfg.App.Mode, cfg.TInvest.AccountID, gateway, repo)
|
||||
execEngine.SetMaxQuoteAge(time.Duration(cfg.Execution.MaxQuoteAgeSec) * time.Second)
|
||||
services := scheduler.Services{
|
||||
Repo: repo,
|
||||
Gateway: gateway,
|
||||
Registry: registry,
|
||||
MarketData: loader,
|
||||
Features: pipeline,
|
||||
Signals: signalEngine,
|
||||
Sizer: sizer,
|
||||
FreeOrders: freeOrders,
|
||||
Risk: riskManager,
|
||||
Execution: &execEngine,
|
||||
Positions: position.NewManager(repo),
|
||||
Reconcile: recon,
|
||||
Notifier: notifier,
|
||||
AccountID: cfg.TInvest.AccountID,
|
||||
AccountIDHash: accountIDHash,
|
||||
Log: log,
|
||||
}
|
||||
return scheduler.New(clock, sm, scheduler.Config{
|
||||
Mode: cfg.App.Mode,
|
||||
Location: cfg.Location,
|
||||
RollingLong: cfg.Strategy.RollingLong,
|
||||
TickInterval: 30 * time.Second,
|
||||
EntrySignalTime: cfg.Execution.EntrySignalTime,
|
||||
EntryWindowStart: cfg.Execution.EntryWindowStart,
|
||||
EntryWindowEnd: cfg.Execution.EntryWindowEnd,
|
||||
NoNewEntryAfter: cfg.Execution.NoNewEntryAfter,
|
||||
ExitWatchStart: cfg.Execution.ExitWatchStart,
|
||||
ExitWindowStart: cfg.Execution.ExitWindowStart,
|
||||
ExitWindowEnd: cfg.Execution.ExitWindowEnd,
|
||||
HardExitDeadline: cfg.Execution.HardExitDeadline,
|
||||
QuoteDepth: cfg.Execution.QuoteDepth,
|
||||
MaxQuoteAge: time.Duration(cfg.Execution.MaxQuoteAgeSec) * time.Second,
|
||||
OrderPollInterval: time.Duration(cfg.Execution.OrderPollIntervalMS) * time.Millisecond,
|
||||
PassiveImproveTicks: cfg.Execution.PassiveImproveTicks,
|
||||
MaxEntryOrderAttempts: cfg.Execution.MaxEntryOrderAttempts,
|
||||
MaxExitOrderAttempts: cfg.Execution.MaxExitOrderAttempts,
|
||||
MinTimeToClose: time.Duration(cfg.Execution.MinTimeToCloseSec) * time.Second,
|
||||
MaxClockDrift: time.Duration(cfg.Risk.MaxClockDriftSec) * time.Second,
|
||||
APIOutageHalt: time.Duration(cfg.Risk.APIOutageHaltSec) * time.Second,
|
||||
}, services)
|
||||
}
|
||||
|
||||
func openDB(ctx context.Context, cfg config.Config) (*sqlx.DB, error) {
|
||||
db, err := sqlx.Open("mysql", cfg.DB.DSN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db.SetMaxOpenConns(cfg.DB.MaxOpenConns)
|
||||
db.SetMaxIdleConns(cfg.DB.MaxIdleConns)
|
||||
db.SetConnMaxLifetime(time.Duration(cfg.DB.ConnMaxLifetimeMin) * time.Minute)
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("db ping: %w", err)
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func buildGateway(ctx context.Context, cfg config.Config, log *slog.Logger) (tinvest.Gateway, func(), error) {
|
||||
switch cfg.App.Mode {
|
||||
case domain.ModePaper:
|
||||
return tinvest.NewFakeGateway(), nil, nil
|
||||
case domain.ModeSandbox:
|
||||
gw, err := tinvest.NewSandboxGateway(ctx, tinvest.Options{
|
||||
Token: cfg.TInvest.Token,
|
||||
AccountID: cfg.TInvest.AccountID,
|
||||
AppName: cfg.TInvest.AppName,
|
||||
RetryCount: cfg.TInvest.RetryCount,
|
||||
RetryBackoff: time.Duration(cfg.TInvest.RetryBackoffSec) * time.Second,
|
||||
Logger: log,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return gw, func() { _ = gw.Close() }, nil
|
||||
case domain.ModeLiveReadonly, domain.ModeLiveTrade:
|
||||
endpoint := cfg.TInvest.Endpoint
|
||||
if cfg.TInvest.UseSandbox {
|
||||
return nil, nil, errors.New("TINVEST_USE_SANDBOX is only allowed with APP_MODE=sandbox")
|
||||
}
|
||||
gw, err := tinvest.NewRealGateway(ctx, tinvest.Options{
|
||||
Token: cfg.TInvest.Token,
|
||||
AccountID: cfg.TInvest.AccountID,
|
||||
Endpoint: endpoint,
|
||||
AppName: cfg.TInvest.AppName,
|
||||
RetryCount: cfg.TInvest.RetryCount,
|
||||
RetryBackoff: time.Duration(cfg.TInvest.RetryBackoffSec) * time.Second,
|
||||
Logger: log,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return gw, func() { _ = gw.Close() }, nil
|
||||
default:
|
||||
return tinvest.NewFakeGateway(), nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func accountHash(accountID string) string {
|
||||
sum := sha256.Sum256([]byte(accountID))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func HealthURL(addr string) string {
|
||||
if strings.HasPrefix(addr, ":") {
|
||||
return "http://127.0.0.1" + addr + "/health"
|
||||
}
|
||||
if _, err := url.ParseRequestURI(addr); err == nil && strings.HasPrefix(addr, "http") {
|
||||
return addr
|
||||
}
|
||||
return "http://" + addr + "/health"
|
||||
}
|
||||
|
||||
func PingDB(ctx context.Context, db *sql.DB) error {
|
||||
return db.PingContext(ctx)
|
||||
}
|
||||
|
||||
@@ -3,49 +3,29 @@ package app
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRunRequiresConfigPath(t *testing.T) {
|
||||
err := Run(context.Background(), Options{})
|
||||
func TestRunRequiresAppMode(t *testing.T) {
|
||||
t.Setenv("APP_MODE", "")
|
||||
err := Run(context.Background(), Options{RunOnce: true})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "config path is required") {
|
||||
if !strings.Contains(err.Error(), "APP_MODE") && !strings.Contains(err.Error(), "MODE") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunReportsMissingConfig(t *testing.T) {
|
||||
err := Run(context.Background(), Options{ConfigPath: "missing.yaml"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "does not exist") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunUsesExistingConfig(t *testing.T) {
|
||||
touch := t.TempDir() + "/config.yaml"
|
||||
if err := os.WriteFile(touch, []byte("instruments: []\n"), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
func TestRunBacktestModeWithoutDB(t *testing.T) {
|
||||
t.Setenv("APP_MODE", "backtest")
|
||||
var stdout bytes.Buffer
|
||||
err := Run(context.Background(), Options{
|
||||
ConfigPath: touch,
|
||||
Stdout: &stdout,
|
||||
})
|
||||
err := Run(context.Background(), Options{Stdout: &stdout, RunOnce: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !strings.Contains(stdout.String(), "initialized") {
|
||||
t.Fatalf("unexpected stdout: %q", stdout.String())
|
||||
if !strings.Contains(stdout.String(), "backtest") {
|
||||
t.Fatalf("unexpected stdout: %s", stdout.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,563 @@
|
||||
package backtest
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/features"
|
||||
"overnight-trading-bot/internal/money"
|
||||
"overnight-trading-bot/internal/risk"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
EntrySlippageBps decimal.Decimal
|
||||
ExitSlippageBps decimal.Decimal
|
||||
CommissionRoundtripBps decimal.Decimal
|
||||
InitialEquity decimal.Decimal
|
||||
OutputDir string
|
||||
RollingShort int
|
||||
RollingLong int
|
||||
EWMALambda float64
|
||||
MinTStat60 decimal.Decimal
|
||||
MinWinRate60 decimal.Decimal
|
||||
MinNetEdgeBps decimal.Decimal
|
||||
MinADVRUB decimal.Decimal
|
||||
MaxSpreadBps decimal.Decimal
|
||||
MaxTickBps decimal.Decimal
|
||||
RequireZeroCommission bool
|
||||
MaxPositions int
|
||||
MaxPositionPct decimal.Decimal
|
||||
MaxTotalExposurePct decimal.Decimal
|
||||
MaxParticipationRate decimal.Decimal
|
||||
CashUsageBuffer decimal.Decimal
|
||||
RiskBudgetPct decimal.Decimal
|
||||
MinOrderNotionalRUB decimal.Decimal
|
||||
AssumedSpreadBps decimal.Decimal
|
||||
AssumedTickBps decimal.Decimal
|
||||
Lot int64
|
||||
UseMinuteModel bool
|
||||
}
|
||||
|
||||
type Trade struct {
|
||||
InstrumentUID string `json:"instrument_uid"`
|
||||
EntryDate string `json:"entry_date"`
|
||||
ExitDate string `json:"exit_date"`
|
||||
BuyPrice decimal.Decimal `json:"buy_price"`
|
||||
SellPrice decimal.Decimal `json:"sell_price"`
|
||||
Return decimal.Decimal `json:"return"`
|
||||
Lots int64 `json:"lots"`
|
||||
Notional decimal.Decimal `json:"notional"`
|
||||
NetPnL decimal.Decimal `json:"net_pnl"`
|
||||
SpreadBps decimal.Decimal `json:"spread_bps"`
|
||||
SlippageBps decimal.Decimal `json:"slippage_bps"`
|
||||
OvernightGap decimal.Decimal `json:"overnight_gap"`
|
||||
CapacityRUB decimal.Decimal `json:"capacity_rub"`
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
Metrics Metrics `json:"metrics"`
|
||||
Trades []Trade `json:"trades"`
|
||||
Equity []Point `json:"equity"`
|
||||
}
|
||||
|
||||
type Point struct {
|
||||
Date string `json:"date"`
|
||||
Equity decimal.Decimal `json:"equity"`
|
||||
Return decimal.Decimal `json:"return"`
|
||||
}
|
||||
|
||||
type Engine struct {
|
||||
cfg Config
|
||||
}
|
||||
|
||||
func New(cfg Config) Engine {
|
||||
cfg = cfg.withDefaults()
|
||||
return Engine{cfg: cfg}
|
||||
}
|
||||
|
||||
func (cfg Config) withDefaults() Config {
|
||||
if cfg.InitialEquity.IsZero() {
|
||||
cfg.InitialEquity = decimal.NewFromInt(100_000)
|
||||
}
|
||||
if cfg.RollingShort == 0 {
|
||||
cfg.RollingShort = 60
|
||||
}
|
||||
if cfg.RollingLong == 0 {
|
||||
cfg.RollingLong = 252
|
||||
}
|
||||
if cfg.EWMALambda == 0 {
|
||||
cfg.EWMALambda = 0.08
|
||||
}
|
||||
if cfg.MinTStat60.IsZero() {
|
||||
cfg.MinTStat60 = decimal.NewFromFloat(1.25)
|
||||
}
|
||||
if cfg.MinWinRate60.IsZero() {
|
||||
cfg.MinWinRate60 = decimal.NewFromFloat(0.55)
|
||||
}
|
||||
if cfg.MinNetEdgeBps.IsZero() {
|
||||
cfg.MinNetEdgeBps = decimal.NewFromInt(10)
|
||||
}
|
||||
if cfg.MinADVRUB.IsZero() {
|
||||
cfg.MinADVRUB = decimal.NewFromInt(5_000_000)
|
||||
}
|
||||
if cfg.MaxSpreadBps.IsZero() {
|
||||
cfg.MaxSpreadBps = decimal.NewFromInt(20)
|
||||
}
|
||||
if cfg.MaxTickBps.IsZero() {
|
||||
cfg.MaxTickBps = decimal.NewFromInt(10)
|
||||
}
|
||||
if !cfg.RequireZeroCommission && cfg.CommissionRoundtripBps.IsZero() {
|
||||
cfg.RequireZeroCommission = true
|
||||
}
|
||||
if cfg.MaxPositions == 0 {
|
||||
cfg.MaxPositions = 5
|
||||
}
|
||||
if cfg.MaxPositionPct.IsZero() {
|
||||
cfg.MaxPositionPct = decimal.NewFromFloat(0.10)
|
||||
}
|
||||
if cfg.MaxTotalExposurePct.IsZero() {
|
||||
cfg.MaxTotalExposurePct = decimal.NewFromFloat(0.50)
|
||||
}
|
||||
if cfg.MaxParticipationRate.IsZero() {
|
||||
cfg.MaxParticipationRate = decimal.NewFromFloat(0.01)
|
||||
}
|
||||
if cfg.CashUsageBuffer.IsZero() {
|
||||
cfg.CashUsageBuffer = decimal.NewFromFloat(0.95)
|
||||
}
|
||||
if cfg.RiskBudgetPct.IsZero() {
|
||||
cfg.RiskBudgetPct = decimal.NewFromFloat(0.005)
|
||||
}
|
||||
if cfg.MinOrderNotionalRUB.IsZero() {
|
||||
cfg.MinOrderNotionalRUB = decimal.NewFromInt(1000)
|
||||
}
|
||||
if cfg.Lot == 0 {
|
||||
cfg.Lot = 1
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (e Engine) Run(candlesByInstrument map[string][]domain.Candle) (Result, error) {
|
||||
return e.RunWithMinuteCandles(candlesByInstrument, nil)
|
||||
}
|
||||
|
||||
func (e Engine) RunWithMinuteCandles(candlesByInstrument map[string][]domain.Candle, minuteCandlesByInstrument map[string][]domain.Candle) (Result, error) {
|
||||
prepared := prepareCandles(candlesByInstrument)
|
||||
preparedMinutes := prepareCandles(minuteCandlesByInstrument)
|
||||
candidatesByExitDate := make(map[string][]candidate)
|
||||
for instrumentUID, candles := range prepared {
|
||||
for i := 1; i < len(candles); i++ {
|
||||
candidate, ok, err := e.evaluateCandidate(instrumentUID, candles, i)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
if ok {
|
||||
candidatesByExitDate[candidate.exit.TradeDate.Format("2006-01-02")] = append(candidatesByExitDate[candidate.exit.TradeDate.Format("2006-01-02")], candidate)
|
||||
}
|
||||
}
|
||||
}
|
||||
dates := make([]string, 0, len(candidatesByExitDate))
|
||||
for date := range candidatesByExitDate {
|
||||
dates = append(dates, date)
|
||||
}
|
||||
sort.Strings(dates)
|
||||
|
||||
equity := e.cfg.InitialEquity
|
||||
cash := e.cfg.InitialEquity
|
||||
var trades []Trade
|
||||
points := []Point{{Date: "START", Equity: equity}}
|
||||
sizer := risk.NewSizer(risk.SizingConfig{
|
||||
MaxPositionPct: e.cfg.MaxPositionPct,
|
||||
MaxTotalExposurePct: e.cfg.MaxTotalExposurePct,
|
||||
MaxParticipationRate: e.cfg.MaxParticipationRate,
|
||||
CashUsageBuffer: e.cfg.CashUsageBuffer,
|
||||
RiskBudgetPerInstrumentPct: e.cfg.RiskBudgetPct,
|
||||
MinOrderNotionalRUB: e.cfg.MinOrderNotionalRUB,
|
||||
})
|
||||
for _, date := range dates {
|
||||
dayCandidates := candidatesByExitDate[date]
|
||||
sort.Slice(dayCandidates, func(i, j int) bool {
|
||||
if dayCandidates[i].netEdge.Equal(dayCandidates[j].netEdge) {
|
||||
return dayCandidates[i].instrumentUID < dayCandidates[j].instrumentUID
|
||||
}
|
||||
return dayCandidates[i].netEdge.GreaterThan(dayCandidates[j].netEdge)
|
||||
})
|
||||
if len(dayCandidates) > e.cfg.MaxPositions {
|
||||
dayCandidates = dayCandidates[:e.cfg.MaxPositions]
|
||||
}
|
||||
dayStartEquity := equity
|
||||
dayPnL := decimal.Zero
|
||||
for _, c := range dayCandidates {
|
||||
sized := sizer.Size(risk.SizingInput{
|
||||
Portfolio: domain.Portfolio{Equity: equity, Cash: cash},
|
||||
SelectedInstruments: len(dayCandidates),
|
||||
LimitPrice: c.buy,
|
||||
Lot: e.cfg.Lot,
|
||||
EntryIntervalVolume: c.adv,
|
||||
ExitIntervalVolume: c.adv,
|
||||
Q05OvernightAbs: c.q05Abs,
|
||||
})
|
||||
if sized.Lots <= 0 {
|
||||
continue
|
||||
}
|
||||
lots := sized.Lots
|
||||
capacity := c.capacity
|
||||
if e.cfg.UseMinuteModel {
|
||||
executedLots, minuteCapacity, ok := e.minuteExecution(c, preparedMinutes[c.instrumentUID], sized.Lots)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
lots = executedLots
|
||||
capacity = minuteCapacity
|
||||
}
|
||||
notional := c.buy.Mul(decimal.NewFromInt(lots)).Mul(decimal.NewFromInt(e.cfg.Lot))
|
||||
ret := c.sell.Div(c.buy).Sub(decimal.NewFromInt(1)).Sub(money.FromBps(e.cfg.CommissionRoundtripBps))
|
||||
pnl := notional.Mul(ret)
|
||||
dayPnL = dayPnL.Add(pnl)
|
||||
cash = cash.Sub(notional)
|
||||
trades = append(trades, Trade{
|
||||
InstrumentUID: c.instrumentUID,
|
||||
EntryDate: c.entry.TradeDate.Format("2006-01-02"),
|
||||
ExitDate: c.exit.TradeDate.Format("2006-01-02"),
|
||||
BuyPrice: c.buy,
|
||||
SellPrice: c.sell,
|
||||
Return: ret,
|
||||
Lots: lots,
|
||||
Notional: notional,
|
||||
NetPnL: pnl,
|
||||
SpreadBps: e.cfg.AssumedSpreadBps,
|
||||
SlippageBps: e.cfg.EntrySlippageBps.Add(e.cfg.ExitSlippageBps),
|
||||
OvernightGap: c.overnightGap,
|
||||
CapacityRUB: capacity,
|
||||
})
|
||||
}
|
||||
if !dayPnL.IsZero() {
|
||||
equity = equity.Add(dayPnL)
|
||||
cash = equity
|
||||
points = append(points, Point{
|
||||
Date: date,
|
||||
Equity: equity,
|
||||
Return: dayPnL.Div(dayStartEquity),
|
||||
})
|
||||
}
|
||||
}
|
||||
sort.Slice(trades, func(i, j int) bool {
|
||||
if trades[i].ExitDate == trades[j].ExitDate {
|
||||
return trades[i].InstrumentUID < trades[j].InstrumentUID
|
||||
}
|
||||
return trades[i].ExitDate < trades[j].ExitDate
|
||||
})
|
||||
return Result{
|
||||
Metrics: ComputeMetrics(points, trades),
|
||||
Trades: trades,
|
||||
Equity: points,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e Engine) minuteExecution(c candidate, minutes []domain.Candle, requestedLots int64) (int64, decimal.Decimal, bool) {
|
||||
if requestedLots <= 0 || len(minutes) == 0 {
|
||||
return 0, decimal.Zero, false
|
||||
}
|
||||
entryLots, entryCapacity := e.fillableMinuteLots(minutes, c.entry.TradeDate, c.buy, domain.SideBuy)
|
||||
exitLots, exitCapacity := e.fillableMinuteLots(minutes, c.exit.TradeDate, c.sell, domain.SideSell)
|
||||
lots := min(requestedLots, entryLots)
|
||||
lots = min(lots, exitLots)
|
||||
if lots <= 0 {
|
||||
return 0, decimal.Zero, false
|
||||
}
|
||||
return lots, money.Min(entryCapacity, exitCapacity), true
|
||||
}
|
||||
|
||||
func (e Engine) fillableMinuteLots(minutes []domain.Candle, date time.Time, limitPrice decimal.Decimal, side domain.Side) (int64, decimal.Decimal) {
|
||||
if !limitPrice.IsPositive() || e.cfg.Lot <= 0 {
|
||||
return 0, decimal.Zero
|
||||
}
|
||||
lotNotional := limitPrice.Mul(decimal.NewFromInt(e.cfg.Lot))
|
||||
if !lotNotional.IsPositive() {
|
||||
return 0, decimal.Zero
|
||||
}
|
||||
capacity := decimal.Zero
|
||||
for _, candle := range minutes {
|
||||
if !sameDate(candle.TradeDate, date) {
|
||||
continue
|
||||
}
|
||||
reachable := side == domain.SideBuy && candle.Low.LessThanOrEqual(limitPrice)
|
||||
reachable = reachable || side == domain.SideSell && candle.High.GreaterThanOrEqual(limitPrice)
|
||||
if !reachable {
|
||||
continue
|
||||
}
|
||||
minuteCapacity := candle.VolumeLots.Mul(lotNotional).Mul(e.cfg.MaxParticipationRate)
|
||||
capacity = capacity.Add(minuteCapacity)
|
||||
}
|
||||
return capacity.Div(lotNotional).Floor().IntPart(), capacity
|
||||
}
|
||||
|
||||
type candidate struct {
|
||||
instrumentUID string
|
||||
entry domain.Candle
|
||||
exit domain.Candle
|
||||
buy decimal.Decimal
|
||||
sell decimal.Decimal
|
||||
netEdge decimal.Decimal
|
||||
adv decimal.Decimal
|
||||
q05Abs decimal.Decimal
|
||||
overnightGap decimal.Decimal
|
||||
capacity decimal.Decimal
|
||||
}
|
||||
|
||||
func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle, exitIndex int) (candidate, bool, error) {
|
||||
if exitIndex < e.cfg.RollingShort || exitIndex <= 0 {
|
||||
return candidate{}, false, nil
|
||||
}
|
||||
history := candles[:exitIndex]
|
||||
returns := make([]float64, 0, len(history)-1)
|
||||
for j := 1; j < len(history); j++ {
|
||||
r, err := features.OvernightReturn(history[j].Open, history[j-1].Close)
|
||||
if err != nil {
|
||||
return candidate{}, false, err
|
||||
}
|
||||
rf, _ := r.Float64()
|
||||
returns = append(returns, rf)
|
||||
}
|
||||
short := features.Rolling(returns, e.cfg.RollingShort, e.cfg.EWMALambda)
|
||||
long := features.Rolling(returns, min(e.cfg.RollingLong, len(returns)), e.cfg.EWMALambda)
|
||||
if !short.Available || !long.Available || short.StdDev == 0 {
|
||||
return candidate{}, false, nil
|
||||
}
|
||||
rawEdge := decimal.NewFromFloat(short.Mean).Mul(decimal.NewFromInt(10_000))
|
||||
cost := e.cfg.AssumedSpreadBps.
|
||||
Add(e.cfg.EntrySlippageBps).
|
||||
Add(e.cfg.ExitSlippageBps).
|
||||
Add(e.cfg.CommissionRoundtripBps)
|
||||
netEdge := rawEdge.Sub(cost)
|
||||
adv := features.ADV(history, e.cfg.Lot, 20)
|
||||
switch {
|
||||
case e.cfg.RequireZeroCommission && e.cfg.CommissionRoundtripBps.IsPositive():
|
||||
return candidate{}, false, nil
|
||||
case !decimal.NewFromFloat(short.Mean).IsPositive() || !decimal.NewFromFloat(long.Mean).IsPositive():
|
||||
return candidate{}, false, nil
|
||||
case decimal.NewFromFloat(short.TStat).LessThan(e.cfg.MinTStat60):
|
||||
return candidate{}, false, nil
|
||||
case decimal.NewFromFloat(short.WinRate).LessThan(e.cfg.MinWinRate60):
|
||||
return candidate{}, false, nil
|
||||
case netEdge.LessThan(e.cfg.MinNetEdgeBps):
|
||||
return candidate{}, false, nil
|
||||
case e.cfg.AssumedSpreadBps.GreaterThan(e.cfg.MaxSpreadBps):
|
||||
return candidate{}, false, nil
|
||||
case e.cfg.AssumedTickBps.GreaterThan(e.cfg.MaxTickBps):
|
||||
return candidate{}, false, nil
|
||||
case adv.LessThan(e.cfg.MinADVRUB):
|
||||
return candidate{}, false, nil
|
||||
}
|
||||
entry := candles[exitIndex-1]
|
||||
exit := candles[exitIndex]
|
||||
buy := entry.Close.Mul(decimal.NewFromInt(1).Add(money.FromBps(e.cfg.EntrySlippageBps)))
|
||||
sell := exit.Open.Mul(decimal.NewFromInt(1).Sub(money.FromBps(e.cfg.ExitSlippageBps)))
|
||||
gap, err := features.OvernightReturn(exit.Open, entry.Close)
|
||||
if err != nil {
|
||||
return candidate{}, false, err
|
||||
}
|
||||
q05Abs := decimal.NewFromFloat(features.Quantile(returns, 0.05))
|
||||
if q05Abs.IsNegative() {
|
||||
q05Abs = q05Abs.Neg()
|
||||
}
|
||||
return candidate{
|
||||
instrumentUID: instrumentUID,
|
||||
entry: entry,
|
||||
exit: exit,
|
||||
buy: buy,
|
||||
sell: sell,
|
||||
netEdge: netEdge,
|
||||
adv: adv,
|
||||
q05Abs: q05Abs,
|
||||
overnightGap: gap,
|
||||
capacity: adv.Mul(e.cfg.MaxParticipationRate),
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
func prepareCandles(candlesByInstrument map[string][]domain.Candle) map[string][]domain.Candle {
|
||||
prepared := make(map[string][]domain.Candle, len(candlesByInstrument))
|
||||
for instrumentUID, candles := range candlesByInstrument {
|
||||
cp := append([]domain.Candle(nil), candles...)
|
||||
sort.Slice(cp, func(i, j int) bool {
|
||||
return cp[i].TradeDate.Before(cp[j].TradeDate)
|
||||
})
|
||||
prepared[instrumentUID] = cp
|
||||
}
|
||||
return prepared
|
||||
}
|
||||
|
||||
func (r Result) Write(outputDir string) error {
|
||||
if outputDir == "" {
|
||||
outputDir = "./backtest_out"
|
||||
}
|
||||
if err := os.MkdirAll(outputDir, 0o750); err != nil {
|
||||
return err
|
||||
}
|
||||
summary, err := json.MarshalIndent(r.Metrics, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(outputDir, "summary.json"), summary, 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeTrades(filepath.Join(outputDir, "trades.csv"), r.Trades); err != nil {
|
||||
return err
|
||||
}
|
||||
return writeEquity(filepath.Join(outputDir, "equity.csv"), r.Equity)
|
||||
}
|
||||
|
||||
func LoadCandlesCSV(reader io.Reader) (map[string][]domain.Candle, error) {
|
||||
r := csv.NewReader(reader)
|
||||
r.FieldsPerRecord = -1
|
||||
records, err := r.ReadAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make(map[string][]domain.Candle)
|
||||
for i, record := range records {
|
||||
if i == 0 && len(record) > 0 && record[0] == "instrument_uid" {
|
||||
continue
|
||||
}
|
||||
if len(record) < 7 {
|
||||
return nil, fmt.Errorf("line %d: expected 7 fields", i+1)
|
||||
}
|
||||
date, err := parseCandleTime(record[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
open, err := decimal.NewFromString(record[2])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
high, err := decimal.NewFromString(record[3])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
low, err := decimal.NewFromString(record[4])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
closePrice, err := decimal.NewFromString(record[5])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
volume, err := decimal.NewFromString(record[6])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
candle := domain.Candle{
|
||||
InstrumentUID: record[0],
|
||||
TradeDate: date,
|
||||
Open: open,
|
||||
High: high,
|
||||
Low: low,
|
||||
Close: closePrice,
|
||||
VolumeLots: volume,
|
||||
Source: "csv",
|
||||
LoadedAt: time.Now().UTC(),
|
||||
}
|
||||
out[candle.InstrumentUID] = append(out[candle.InstrumentUID], candle)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func parseCandleTime(raw string) (time.Time, error) {
|
||||
layouts := []string{
|
||||
time.RFC3339,
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02T15:04:05",
|
||||
"2006-01-02",
|
||||
}
|
||||
var lastErr error
|
||||
for _, layout := range layouts {
|
||||
parsed, err := time.Parse(layout, raw)
|
||||
if err == nil {
|
||||
return parsed.UTC(), nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
return time.Time{}, lastErr
|
||||
}
|
||||
|
||||
func sameDate(a, b time.Time) bool {
|
||||
return dateOnly(a).Equal(dateOnly(b))
|
||||
}
|
||||
|
||||
func dateOnly(t time.Time) time.Time {
|
||||
y, m, d := t.UTC().Date()
|
||||
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
func writeTrades(path string, trades []Trade) error {
|
||||
// #nosec G304 -- path is the user-selected backtest output destination.
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
w := csv.NewWriter(f)
|
||||
defer w.Flush()
|
||||
if err := w.Write([]string{"instrument_uid", "entry_date", "exit_date", "buy_price", "sell_price", "return", "lots", "notional", "net_pnl", "spread_bps", "slippage_bps", "overnight_gap", "capacity_rub"}); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, trade := range trades {
|
||||
if err := w.Write([]string{
|
||||
trade.InstrumentUID,
|
||||
trade.EntryDate,
|
||||
trade.ExitDate,
|
||||
trade.BuyPrice.String(),
|
||||
trade.SellPrice.String(),
|
||||
trade.Return.String(),
|
||||
fmt.Sprintf("%d", trade.Lots),
|
||||
trade.Notional.String(),
|
||||
trade.NetPnL.String(),
|
||||
trade.SpreadBps.String(),
|
||||
trade.SlippageBps.String(),
|
||||
trade.OvernightGap.String(),
|
||||
trade.CapacityRUB.String(),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return w.Error()
|
||||
}
|
||||
|
||||
func writeEquity(path string, points []Point) error {
|
||||
// #nosec G304 -- path is the user-selected backtest output destination.
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
w := csv.NewWriter(f)
|
||||
defer w.Flush()
|
||||
if err := w.Write([]string{"date", "equity", "return"}); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, point := range points {
|
||||
if err := w.Write([]string{point.Date, point.Equity.String(), point.Return.String()}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return w.Error()
|
||||
}
|
||||
|
||||
func ParseDecimalFlag(raw string) (decimal.Decimal, error) {
|
||||
if raw == "" {
|
||||
return decimal.Zero, nil
|
||||
}
|
||||
return decimal.NewFromString(raw)
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package backtest
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
func TestBacktestNoLookAheadWithFutureOnlyEdge(t *testing.T) {
|
||||
var candles []domain.Candle
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
for i := 0; i < 80; i++ {
|
||||
open := decimal.NewFromInt(100)
|
||||
if i == 79 {
|
||||
open = decimal.NewFromInt(110)
|
||||
}
|
||||
candles = append(candles, domain.Candle{
|
||||
InstrumentUID: "uid",
|
||||
TradeDate: start.AddDate(0, 0, i),
|
||||
Open: open,
|
||||
High: open,
|
||||
Low: open,
|
||||
Close: decimal.NewFromInt(100),
|
||||
})
|
||||
}
|
||||
result, err := New(Config{}).Run(map[string][]domain.Candle{"uid": candles})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(result.Trades) != 0 {
|
||||
t.Fatalf("future-only edge leaked into signals: %d trades", len(result.Trades))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinuteExecutionRequiresReachableLimitAndParticipation(t *testing.T) {
|
||||
engine := New(Config{
|
||||
Lot: 10,
|
||||
MaxParticipationRate: decimal.NewFromFloat(0.10),
|
||||
})
|
||||
entryDate := time.Date(2024, 1, 2, 18, 25, 0, 0, time.UTC)
|
||||
exitDate := time.Date(2024, 1, 3, 10, 5, 0, 0, time.UTC)
|
||||
c := candidate{
|
||||
instrumentUID: "uid",
|
||||
entry: domain.Candle{TradeDate: entryDate},
|
||||
exit: domain.Candle{TradeDate: exitDate},
|
||||
buy: decimal.NewFromInt(100),
|
||||
sell: decimal.NewFromInt(105),
|
||||
}
|
||||
minutes := []domain.Candle{
|
||||
{TradeDate: entryDate, Low: decimal.NewFromInt(99), High: decimal.NewFromInt(101), VolumeLots: decimal.NewFromInt(20)},
|
||||
{TradeDate: exitDate, Low: decimal.NewFromInt(104), High: decimal.NewFromInt(106), VolumeLots: decimal.NewFromInt(20)},
|
||||
}
|
||||
lots, capacity, ok := engine.minuteExecution(c, minutes, 5)
|
||||
if !ok {
|
||||
t.Fatal("expected minute execution")
|
||||
}
|
||||
if lots != 2 {
|
||||
t.Fatalf("lots=%d, want 2", lots)
|
||||
}
|
||||
if !capacity.Equal(decimal.NewFromInt(2000)) {
|
||||
t.Fatalf("capacity=%s, want 2000", capacity)
|
||||
}
|
||||
c.sell = decimal.NewFromInt(110)
|
||||
if _, _, ok := engine.minuteExecution(c, minutes, 5); ok {
|
||||
t.Fatal("sell limit should be unreachable")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
package backtest
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sort"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
type Metrics struct {
|
||||
TotalReturn float64 `json:"total_return"`
|
||||
CAGR float64 `json:"cagr"`
|
||||
AnnualizedVolatility float64 `json:"annualized_volatility"`
|
||||
SharpeRatio float64 `json:"sharpe_ratio"`
|
||||
SortinoRatio float64 `json:"sortino_ratio"`
|
||||
MaxDrawdown float64 `json:"max_drawdown"`
|
||||
CalmarRatio float64 `json:"calmar_ratio"`
|
||||
WinRate float64 `json:"win_rate"`
|
||||
AverageTradeReturn float64 `json:"average_trade_return"`
|
||||
MedianTradeReturn float64 `json:"median_trade_return"`
|
||||
ProfitFactor float64 `json:"profit_factor"`
|
||||
AverageSpreadBps float64 `json:"average_spread_bps"`
|
||||
AverageSlippageBps float64 `json:"average_slippage_bps"`
|
||||
NumberOfTrades int `json:"number_of_trades"`
|
||||
WorstOvernightGap float64 `json:"worst_overnight_gap"`
|
||||
VaR95 float64 `json:"var_95"`
|
||||
CVaR95 float64 `json:"cvar_95"`
|
||||
CapacityEstimate float64 `json:"capacity_estimate"`
|
||||
}
|
||||
|
||||
func ComputeMetrics(points []Point, trades []Trade) Metrics {
|
||||
if len(points) == 0 {
|
||||
return Metrics{}
|
||||
}
|
||||
start, _ := points[0].Equity.Float64()
|
||||
end, _ := points[len(points)-1].Equity.Float64()
|
||||
returns := make([]float64, 0, len(points)-1)
|
||||
for _, point := range points[1:] {
|
||||
r, _ := point.Return.Float64()
|
||||
returns = append(returns, r)
|
||||
}
|
||||
tradeReturns := make([]float64, 0, len(trades))
|
||||
spreads := make([]float64, 0, len(trades))
|
||||
slippages := make([]float64, 0, len(trades))
|
||||
profits := 0.0
|
||||
losses := 0.0
|
||||
wins := 0
|
||||
worstGap := 0.0
|
||||
capacity := 0.0
|
||||
for _, trade := range trades {
|
||||
r, _ := trade.Return.Float64()
|
||||
tradeReturns = append(tradeReturns, r)
|
||||
spread, _ := trade.SpreadBps.Float64()
|
||||
spreads = append(spreads, spread)
|
||||
slippage, _ := trade.SlippageBps.Float64()
|
||||
slippages = append(slippages, slippage)
|
||||
if r > 0 {
|
||||
wins++
|
||||
profits += r
|
||||
} else {
|
||||
losses += r
|
||||
}
|
||||
gap, _ := trade.OvernightGap.Float64()
|
||||
if gap < worstGap {
|
||||
worstGap = gap
|
||||
}
|
||||
tradeCapacity, _ := trade.CapacityRUB.Float64()
|
||||
if tradeCapacity > 0 && (capacity == 0 || tradeCapacity < capacity) {
|
||||
capacity = tradeCapacity
|
||||
}
|
||||
}
|
||||
totalReturn := 0.0
|
||||
if start > 0 {
|
||||
totalReturn = end/start - 1
|
||||
}
|
||||
vol := stddev(returns) * math.Sqrt(252)
|
||||
meanReturn := mean(returns)
|
||||
sharpe := 0.0
|
||||
if std := stddev(returns); std > 0 {
|
||||
sharpe = meanReturn / std * math.Sqrt(252)
|
||||
}
|
||||
sortino := 0.0
|
||||
if down := downsideStddev(returns); down > 0 {
|
||||
sortino = meanReturn / down * math.Sqrt(252)
|
||||
}
|
||||
tradingDays := math.Max(float64(len(returns)), 1)
|
||||
cagr := 0.0
|
||||
if start > 0 && end > 0 {
|
||||
cagr = math.Pow(end/start, 252/tradingDays) - 1
|
||||
}
|
||||
maxDD := maxDrawdown(points)
|
||||
calmar := 0.0
|
||||
if maxDD != 0 {
|
||||
calmar = cagr / math.Abs(maxDD)
|
||||
}
|
||||
pf := 0.0
|
||||
if losses != 0 {
|
||||
pf = profits / math.Abs(losses)
|
||||
}
|
||||
var95 := percentile(returns, 0.05)
|
||||
cvar95 := conditionalMean(returns, var95)
|
||||
return Metrics{
|
||||
TotalReturn: totalReturn,
|
||||
CAGR: cagr,
|
||||
AnnualizedVolatility: vol,
|
||||
SharpeRatio: sharpe,
|
||||
SortinoRatio: sortino,
|
||||
MaxDrawdown: maxDD,
|
||||
CalmarRatio: calmar,
|
||||
WinRate: ratio(wins, len(tradeReturns)),
|
||||
AverageTradeReturn: mean(tradeReturns),
|
||||
MedianTradeReturn: percentile(tradeReturns, 0.50),
|
||||
ProfitFactor: pf,
|
||||
AverageSpreadBps: mean(spreads),
|
||||
AverageSlippageBps: mean(slippages),
|
||||
NumberOfTrades: len(trades),
|
||||
WorstOvernightGap: worstGap,
|
||||
VaR95: var95,
|
||||
CVaR95: cvar95,
|
||||
CapacityEstimate: capacity,
|
||||
}
|
||||
}
|
||||
|
||||
func maxDrawdown(points []Point) float64 {
|
||||
peak := 0.0
|
||||
maxDD := 0.0
|
||||
for _, point := range points {
|
||||
e, _ := point.Equity.Float64()
|
||||
if e > peak {
|
||||
peak = e
|
||||
}
|
||||
if peak > 0 {
|
||||
dd := e/peak - 1
|
||||
if dd < maxDD {
|
||||
maxDD = dd
|
||||
}
|
||||
}
|
||||
}
|
||||
return maxDD
|
||||
}
|
||||
|
||||
func mean(values []float64) float64 {
|
||||
if len(values) == 0 {
|
||||
return 0
|
||||
}
|
||||
sum := 0.0
|
||||
for _, value := range values {
|
||||
sum += value
|
||||
}
|
||||
return sum / float64(len(values))
|
||||
}
|
||||
|
||||
func stddev(values []float64) float64 {
|
||||
if len(values) < 2 {
|
||||
return 0
|
||||
}
|
||||
m := mean(values)
|
||||
sum := 0.0
|
||||
for _, value := range values {
|
||||
diff := value - m
|
||||
sum += diff * diff
|
||||
}
|
||||
return math.Sqrt(sum / float64(len(values)-1))
|
||||
}
|
||||
|
||||
func downsideStddev(values []float64) float64 {
|
||||
var downs []float64
|
||||
for _, value := range values {
|
||||
if value < 0 {
|
||||
downs = append(downs, value)
|
||||
}
|
||||
}
|
||||
return stddev(downs)
|
||||
}
|
||||
|
||||
func percentile(values []float64, q float64) float64 {
|
||||
if len(values) == 0 {
|
||||
return 0
|
||||
}
|
||||
cp := append([]float64(nil), values...)
|
||||
sort.Float64s(cp)
|
||||
pos := q * float64(len(cp)-1)
|
||||
lower := int(math.Floor(pos))
|
||||
upper := int(math.Ceil(pos))
|
||||
if lower == upper {
|
||||
return cp[lower]
|
||||
}
|
||||
weight := pos - float64(lower)
|
||||
return cp[lower]*(1-weight) + cp[upper]*weight
|
||||
}
|
||||
|
||||
func conditionalMean(values []float64, threshold float64) float64 {
|
||||
var selected []float64
|
||||
for _, value := range values {
|
||||
if value <= threshold {
|
||||
selected = append(selected, value)
|
||||
}
|
||||
}
|
||||
return mean(selected)
|
||||
}
|
||||
|
||||
func ratio(n, d int) float64 {
|
||||
if d == 0 {
|
||||
return 0
|
||||
}
|
||||
return float64(n) / float64(d)
|
||||
}
|
||||
|
||||
func point(date string, equity, ret string) Point {
|
||||
e, _ := decimal.NewFromString(equity)
|
||||
r, _ := decimal.NewFromString(ret)
|
||||
return Point{Date: date, Equity: e, Return: r}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package backtest
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestMetrics(t *testing.T) {
|
||||
got := ComputeMetrics([]Point{
|
||||
point("START", "100", "0"),
|
||||
point("2024-01-02", "110", "0.10"),
|
||||
point("2024-01-03", "99", "-0.10"),
|
||||
}, []Trade{{Return: point("", "0", "0.10").Return}, {Return: point("", "0", "-0.10").Return}})
|
||||
if got.NumberOfTrades != 2 || got.WinRate != 0.5 || got.MaxDrawdown >= 0 || got.VaR95 >= 0 {
|
||||
t.Fatalf("unexpected metrics: %+v", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,234 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/caarlos0/env/v11"
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/timeutil"
|
||||
)
|
||||
|
||||
const liveTradeAck = "I_ACCEPT_RISK"
|
||||
const maxQuoteDepth = 50
|
||||
|
||||
type Config struct {
|
||||
App AppConfig `envPrefix:"APP_"`
|
||||
TInvest TInvestConfig `envPrefix:"TINVEST_"`
|
||||
DB DBConfig `envPrefix:"DB_"`
|
||||
Telegram TelegramConfig `envPrefix:"TELEGRAM_"`
|
||||
Strategy StrategyConfig `envPrefix:"STRATEGY_"`
|
||||
Execution ExecutionConfig `envPrefix:"EXEC_"`
|
||||
Risk RiskConfig `envPrefix:"RISK_"`
|
||||
Liquidity LiquidityConfig `envPrefix:"LIQ_"`
|
||||
Commission CommissionConfig `envPrefix:"COMM_"`
|
||||
Backtest BacktestConfig `envPrefix:"BT_"`
|
||||
Live LiveConfig `envPrefix:"LIVE_"`
|
||||
|
||||
Location *time.Location `env:"-"`
|
||||
}
|
||||
|
||||
type AppConfig struct {
|
||||
Mode domain.Mode `env:"MODE,required"`
|
||||
Timezone string `env:"TIMEZONE" envDefault:"Europe/Moscow"`
|
||||
LogLevel string `env:"LOG_LEVEL" envDefault:"info"`
|
||||
HealthcheckAddr string `env:"HEALTHCHECK_ADDR" envDefault:":3300"`
|
||||
ShutdownTimeoutSec int `env:"SHUTDOWN_TIMEOUT_SEC" envDefault:"30"`
|
||||
}
|
||||
|
||||
type TInvestConfig struct {
|
||||
Token string `env:"TOKEN"`
|
||||
AccountID string `env:"ACCOUNT_ID"`
|
||||
Endpoint string `env:"ENDPOINT" envDefault:"invest-public-api.tinkoff.ru:443"`
|
||||
AppName string `env:"APP_NAME" envDefault:"overnight-trading-bot"`
|
||||
RequestTimeoutSec int `env:"REQUEST_TIMEOUT_SEC" envDefault:"10"`
|
||||
RetryCount int `env:"RETRY_COUNT" envDefault:"3"`
|
||||
RetryBackoffSec int `env:"RETRY_BACKOFF_SEC" envDefault:"2"`
|
||||
UseSandbox bool `env:"USE_SANDBOX" envDefault:"false"`
|
||||
}
|
||||
|
||||
type DBConfig struct {
|
||||
DSN string `env:"DSN"`
|
||||
MaxOpenConns int `env:"MAX_OPEN_CONNS" envDefault:"20"`
|
||||
MaxIdleConns int `env:"MAX_IDLE_CONNS" envDefault:"5"`
|
||||
ConnMaxLifetimeMin int `env:"CONN_MAX_LIFETIME_MIN" envDefault:"30"`
|
||||
MigrationsAutoApply bool `env:"MIGRATIONS_AUTO_APPLY" envDefault:"true"`
|
||||
}
|
||||
|
||||
type TelegramConfig struct {
|
||||
BotToken string `env:"BOT_TOKEN"`
|
||||
ChatID int64 `env:"CHAT_ID"`
|
||||
NotifyInfo bool `env:"NOTIFY_INFO" envDefault:"true"`
|
||||
NotifyWarn bool `env:"NOTIFY_WARN" envDefault:"true"`
|
||||
NotifyAlert bool `env:"NOTIFY_ALERT" envDefault:"true"`
|
||||
NotifyReport bool `env:"NOTIFY_REPORT" envDefault:"true"`
|
||||
}
|
||||
|
||||
type StrategyConfig struct {
|
||||
RollingShort int `env:"ROLLING_SHORT" envDefault:"60"`
|
||||
RollingLong int `env:"ROLLING_LONG" envDefault:"252"`
|
||||
EWMALambda float64 `env:"EWMA_LAMBDA" envDefault:"0.08"`
|
||||
MinTStat60 decimal.Decimal `env:"MIN_TSTAT_60" envDefault:"1.25"`
|
||||
MinWinRate60 decimal.Decimal `env:"MIN_WIN_RATE_60" envDefault:"0.55"`
|
||||
MinNetEdgeBps decimal.Decimal `env:"MIN_NET_EDGE_BPS" envDefault:"10"`
|
||||
RiskBufferBps decimal.Decimal `env:"RISK_BUFFER_BPS" envDefault:"5"`
|
||||
MaxPositions int `env:"MAX_POSITIONS" envDefault:"5"`
|
||||
}
|
||||
|
||||
type ExecutionConfig struct {
|
||||
EntrySignalTime timeutil.TimeOfDay `env:"ENTRY_SIGNAL_TIME" envDefault:"18:10:00"`
|
||||
EntryWindowStart timeutil.TimeOfDay `env:"ENTRY_WINDOW_START" envDefault:"18:20:00"`
|
||||
EntryWindowEnd timeutil.TimeOfDay `env:"ENTRY_WINDOW_END" envDefault:"18:38:30"`
|
||||
NoNewEntryAfter timeutil.TimeOfDay `env:"NO_NEW_ENTRY_AFTER" envDefault:"18:38:30"`
|
||||
ExitWatchStart timeutil.TimeOfDay `env:"EXIT_WATCH_START" envDefault:"09:50:00"`
|
||||
ExitNotBefore timeutil.TimeOfDay `env:"EXIT_NOT_BEFORE" envDefault:"10:03:00"`
|
||||
ExitWindowStart timeutil.TimeOfDay `env:"EXIT_WINDOW_START" envDefault:"10:05:00"`
|
||||
ExitWindowEnd timeutil.TimeOfDay `env:"EXIT_WINDOW_END" envDefault:"10:25:00"`
|
||||
HardExitDeadline timeutil.TimeOfDay `env:"HARD_EXIT_DEADLINE" envDefault:"10:45:00"`
|
||||
MinTimeToCloseSec int `env:"MIN_TIME_TO_CLOSE_SEC" envDefault:"90"`
|
||||
AllowMarketOrders bool `env:"ALLOW_MARKET_ORDERS" envDefault:"false"`
|
||||
MaxEntryOrderAttempts int `env:"MAX_ENTRY_ORDER_ATTEMPTS" envDefault:"3"`
|
||||
MaxExitOrderAttempts int `env:"MAX_EXIT_ORDER_ATTEMPTS" envDefault:"3"`
|
||||
PassiveImproveTicks int `env:"PASSIVE_IMPROVE_TICKS" envDefault:"1"`
|
||||
QuoteDepth int32 `env:"QUOTE_DEPTH" envDefault:"20"`
|
||||
MaxQuoteAgeSec int `env:"MAX_QUOTE_AGE_SEC" envDefault:"3"`
|
||||
OrderPollIntervalMS int `env:"ORDER_POLL_INTERVAL_MS" envDefault:"500"`
|
||||
}
|
||||
|
||||
type RiskConfig struct {
|
||||
UseMargin bool `env:"USE_MARGIN" envDefault:"false"`
|
||||
AllowShort bool `env:"ALLOW_SHORT" envDefault:"false"`
|
||||
MaxTotalExposurePct decimal.Decimal `env:"MAX_TOTAL_EXPOSURE_PCT" envDefault:"0.50"`
|
||||
MaxPositionPct decimal.Decimal `env:"MAX_POSITION_PCT" envDefault:"0.10"`
|
||||
MaxDailyLossPct decimal.Decimal `env:"MAX_DAILY_LOSS_PCT" envDefault:"0.01"`
|
||||
MaxWeeklyLossPct decimal.Decimal `env:"MAX_WEEKLY_LOSS_PCT" envDefault:"0.03"`
|
||||
MaxMonthlyDrawdownPct decimal.Decimal `env:"MAX_MONTHLY_DRAWDOWN_PCT" envDefault:"0.07"`
|
||||
MaxOpenPositions int `env:"MAX_OPEN_POSITIONS" envDefault:"5"`
|
||||
MaxAvgSlippageBps10Trades decimal.Decimal `env:"MAX_AVG_SLIPPAGE_BPS_10_TRADES" envDefault:"15"`
|
||||
APIOutageHaltSec int `env:"API_OUTAGE_HALT_SEC" envDefault:"180"`
|
||||
MaxClockDriftSec int `env:"MAX_CLOCK_DRIFT_SEC" envDefault:"2"`
|
||||
ReconciliationWindowHours int `env:"RECONCILIATION_WINDOW_HOURS" envDefault:"72"`
|
||||
ReconciliationSkewSec int `env:"RECONCILIATION_SKEW_SEC" envDefault:"10"`
|
||||
CashUsageBuffer decimal.Decimal `env:"CASH_USAGE_BUFFER" envDefault:"0.95"`
|
||||
RiskBudgetPerInstrumentPct decimal.Decimal `env:"RISK_BUDGET_PER_INSTRUMENT_PCT" envDefault:"0.005"`
|
||||
MinOrderNotionalRUB decimal.Decimal `env:"MIN_ORDER_NOTIONAL_RUB" envDefault:"1000"`
|
||||
}
|
||||
|
||||
type LiquidityConfig struct {
|
||||
MinADVRUB decimal.Decimal `env:"MIN_ADV_RUB" envDefault:"5000000"`
|
||||
MaxParticipationRate decimal.Decimal `env:"MAX_PARTICIPATION_RATE" envDefault:"0.01"`
|
||||
MaxSpreadBpsDefault decimal.Decimal `env:"MAX_SPREAD_BPS_DEFAULT" envDefault:"20"`
|
||||
MaxSpreadBpsMoneyMarket decimal.Decimal `env:"MAX_SPREAD_BPS_MONEY_MARKET" envDefault:"5"`
|
||||
MaxSpreadBpsBondFunds decimal.Decimal `env:"MAX_SPREAD_BPS_BOND_FUNDS" envDefault:"10"`
|
||||
MaxSpreadBpsEquityFunds decimal.Decimal `env:"MAX_SPREAD_BPS_EQUITY_FUNDS" envDefault:"25"`
|
||||
MaxTickBps decimal.Decimal `env:"MAX_TICK_BPS" envDefault:"10"`
|
||||
}
|
||||
|
||||
type CommissionConfig struct {
|
||||
RequireZeroCommission bool `env:"REQUIRE_ZERO_COMMISSION" envDefault:"true"`
|
||||
QuarantineOnNonZero bool `env:"QUARANTINE_ON_NONZERO" envDefault:"true"`
|
||||
FreeOrderCountPolicy string `env:"FREE_ORDER_COUNT_POLICY" envDefault:"submitted"`
|
||||
}
|
||||
|
||||
type BacktestConfig struct {
|
||||
DateFrom string `env:"DATE_FROM"`
|
||||
DateTo string `env:"DATE_TO"`
|
||||
EntrySlippageBps decimal.Decimal `env:"ENTRY_SLIPPAGE_BPS" envDefault:"8"`
|
||||
ExitSlippageBps decimal.Decimal `env:"EXIT_SLIPPAGE_BPS" envDefault:"8"`
|
||||
CommissionRoundtripBps decimal.Decimal `env:"COMMISSION_ROUNDTRIP_BPS" envDefault:"0"`
|
||||
UseMinuteModel bool `env:"USE_MINUTE_MODEL" envDefault:"false"`
|
||||
OutputDir string `env:"OUTPUT_DIR" envDefault:"./backtest_out"`
|
||||
}
|
||||
|
||||
type LiveConfig struct {
|
||||
TradeAck string `env:"TRADE_ACK"`
|
||||
}
|
||||
|
||||
func Load() (Config, error) {
|
||||
var cfg Config
|
||||
if err := env.Parse(&cfg); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
if c.App.Mode == "" {
|
||||
return errors.New("APP_MODE is required")
|
||||
}
|
||||
loc, err := time.LoadLocation(c.App.Timezone)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load timezone %q: %w", c.App.Timezone, err)
|
||||
}
|
||||
if c.App.Timezone != "Europe/Moscow" {
|
||||
return fmt.Errorf("APP_TIMEZONE must be Europe/Moscow, got %q", c.App.Timezone)
|
||||
}
|
||||
c.Location = loc
|
||||
|
||||
if c.App.ShutdownTimeoutSec <= 0 {
|
||||
return errors.New("APP_SHUTDOWN_TIMEOUT_SEC must be positive")
|
||||
}
|
||||
if c.Execution.AllowMarketOrders {
|
||||
return errors.New("EXEC_ALLOW_MARKET_ORDERS must remain false: strategy is LIMIT-only")
|
||||
}
|
||||
if c.Execution.QuoteDepth <= 0 || c.Execution.QuoteDepth > maxQuoteDepth {
|
||||
return fmt.Errorf("EXEC_QUOTE_DEPTH must be between 1 and %d", maxQuoteDepth)
|
||||
}
|
||||
if c.Execution.OrderPollIntervalMS <= 0 {
|
||||
return errors.New("EXEC_ORDER_POLL_INTERVAL_MS must be positive")
|
||||
}
|
||||
if c.Risk.UseMargin {
|
||||
return errors.New("RISK_USE_MARGIN must remain false")
|
||||
}
|
||||
if c.Risk.AllowShort {
|
||||
return errors.New("RISK_ALLOW_SHORT must remain false")
|
||||
}
|
||||
if c.Risk.APIOutageHaltSec <= 0 {
|
||||
return errors.New("RISK_API_OUTAGE_HALT_SEC must be positive")
|
||||
}
|
||||
if c.Risk.ReconciliationWindowHours <= 0 {
|
||||
return errors.New("RISK_RECONCILIATION_WINDOW_HOURS must be positive")
|
||||
}
|
||||
if c.Risk.ReconciliationSkewSec < 0 {
|
||||
return errors.New("RISK_RECONCILIATION_SKEW_SEC must be non-negative")
|
||||
}
|
||||
if c.Commission.FreeOrderCountPolicy != "submitted" {
|
||||
return fmt.Errorf("COMM_FREE_ORDER_COUNT_POLICY must be submitted, got %q", c.Commission.FreeOrderCountPolicy)
|
||||
}
|
||||
if err := c.validateWindows(); err != nil {
|
||||
return err
|
||||
}
|
||||
if c.App.Mode != domain.ModeBacktest && c.DB.DSN == "" {
|
||||
return errors.New("DB_DSN is required outside backtest mode")
|
||||
}
|
||||
if (c.App.Mode == domain.ModeSandbox || c.App.Mode == domain.ModeLiveReadonly || c.App.Mode == domain.ModeLiveTrade) && c.TInvest.Token == "" {
|
||||
return fmt.Errorf("TINVEST_TOKEN is required for APP_MODE=%s", c.App.Mode)
|
||||
}
|
||||
if c.TInvest.UseSandbox && c.App.Mode != domain.ModeSandbox {
|
||||
return errors.New("TINVEST_USE_SANDBOX=true is only valid with APP_MODE=sandbox")
|
||||
}
|
||||
if c.App.Mode == domain.ModeLiveTrade && c.Live.TradeAck != liveTradeAck {
|
||||
return fmt.Errorf("LIVE_TRADE_ACK=%s is required for APP_MODE=live_trade", liveTradeAck)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Config) validateWindows() error {
|
||||
if c.Execution.EntryWindowStart.Duration >= c.Execution.EntryWindowEnd.Duration ||
|
||||
c.Execution.EntryWindowEnd.Duration > c.Execution.NoNewEntryAfter.Duration {
|
||||
return errors.New("entry windows must satisfy EXEC_ENTRY_WINDOW_START < EXEC_ENTRY_WINDOW_END <= EXEC_NO_NEW_ENTRY_AFTER")
|
||||
}
|
||||
if c.Execution.ExitWatchStart.Duration > c.Execution.ExitNotBefore.Duration ||
|
||||
c.Execution.ExitNotBefore.Duration > c.Execution.ExitWindowStart.Duration ||
|
||||
c.Execution.ExitWindowStart.Duration >= c.Execution.ExitWindowEnd.Duration ||
|
||||
c.Execution.ExitWindowEnd.Duration > c.Execution.HardExitDeadline.Duration {
|
||||
return errors.New("exit windows must be monotonic from EXEC_EXIT_WATCH_START to EXEC_HARD_EXIT_DEADLINE")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
type Mode string
|
||||
|
||||
const (
|
||||
ModeBacktest Mode = "backtest"
|
||||
ModePaper Mode = "paper"
|
||||
ModeSandbox Mode = "sandbox"
|
||||
ModeLiveReadonly Mode = "live_readonly"
|
||||
ModeLiveTrade Mode = "live_trade"
|
||||
)
|
||||
|
||||
func ParseMode(raw string) (Mode, error) {
|
||||
mode := Mode(strings.TrimSpace(raw))
|
||||
switch mode {
|
||||
case ModeBacktest, ModePaper, ModeSandbox, ModeLiveReadonly, ModeLiveTrade:
|
||||
return mode, nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported app mode %q", raw)
|
||||
}
|
||||
}
|
||||
|
||||
func (m Mode) AllowsBrokerOrders() bool {
|
||||
return m == ModeSandbox || m == ModeLiveTrade
|
||||
}
|
||||
|
||||
func (m *Mode) UnmarshalText(text []byte) error {
|
||||
mode, err := ParseMode(string(text))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*m = mode
|
||||
return nil
|
||||
}
|
||||
|
||||
type Side string
|
||||
|
||||
const (
|
||||
SideBuy Side = "BUY"
|
||||
SideSell Side = "SELL"
|
||||
)
|
||||
|
||||
type OrderType string
|
||||
|
||||
const (
|
||||
OrderTypeLimit OrderType = "LIMIT"
|
||||
)
|
||||
|
||||
type OrderStatus string
|
||||
|
||||
const (
|
||||
OrderStatusNew OrderStatus = "NEW"
|
||||
OrderStatusSent OrderStatus = "SENT"
|
||||
OrderStatusPartiallyFilled OrderStatus = "PARTIALLY_FILLED"
|
||||
OrderStatusFilled OrderStatus = "FILLED"
|
||||
OrderStatusCancelled OrderStatus = "CANCELLED"
|
||||
OrderStatusRejected OrderStatus = "REJECTED"
|
||||
OrderStatusExpired OrderStatus = "EXPIRED"
|
||||
OrderStatusFailed OrderStatus = "FAILED"
|
||||
)
|
||||
|
||||
type SignalDecision string
|
||||
|
||||
const (
|
||||
DecisionEnter SignalDecision = "ENTER"
|
||||
DecisionSkip SignalDecision = "SKIP"
|
||||
DecisionReject SignalDecision = "REJECT"
|
||||
)
|
||||
|
||||
type PositionStatus string
|
||||
|
||||
const (
|
||||
PositionNoPosition PositionStatus = "NO_POSITION"
|
||||
PositionEntrySignalled PositionStatus = "ENTRY_SIGNALLED"
|
||||
PositionEntryOrderSent PositionStatus = "ENTRY_ORDER_SENT"
|
||||
PositionEntryPartiallyFilled PositionStatus = "ENTRY_PARTIALLY_FILLED"
|
||||
PositionEntryFilled PositionStatus = "ENTRY_FILLED"
|
||||
PositionHoldingOvernight PositionStatus = "HOLDING_OVERNIGHT"
|
||||
PositionExitOrderSent PositionStatus = "EXIT_ORDER_SENT"
|
||||
PositionExitPartiallyFilled PositionStatus = "EXIT_PARTIALLY_FILLED"
|
||||
PositionExitFilled PositionStatus = "EXIT_FILLED"
|
||||
PositionExitFailed PositionStatus = "EXIT_FAILED"
|
||||
PositionQuarantine PositionStatus = "QUARANTINE"
|
||||
)
|
||||
|
||||
type SystemState string
|
||||
|
||||
const (
|
||||
StateInit SystemState = "INIT"
|
||||
StateSyncInstruments SystemState = "SYNC_INSTRUMENTS"
|
||||
StateSyncMarketData SystemState = "SYNC_MARKET_DATA"
|
||||
StateGenerateSignals SystemState = "GENERATE_SIGNALS"
|
||||
StateWaitEntryWindow SystemState = "WAIT_ENTRY_WINDOW"
|
||||
StatePlaceEntryOrders SystemState = "PLACE_ENTRY_ORDERS"
|
||||
StateMonitorEntryOrders SystemState = "MONITOR_ENTRY_ORDERS"
|
||||
StateHoldOvernight SystemState = "HOLD_OVERNIGHT"
|
||||
StateWaitExitWindow SystemState = "WAIT_EXIT_WINDOW"
|
||||
StatePlaceExitOrders SystemState = "PLACE_EXIT_ORDERS"
|
||||
StateMonitorExitOrders SystemState = "MONITOR_EXIT_ORDERS"
|
||||
StateReconcile SystemState = "RECONCILE"
|
||||
StateReport SystemState = "REPORT"
|
||||
StateSleep SystemState = "SLEEP"
|
||||
StateHalted SystemState = "HALTED"
|
||||
)
|
||||
|
||||
type Severity string
|
||||
|
||||
const (
|
||||
SeverityInfo Severity = "INFO"
|
||||
SeverityWarn Severity = "WARN"
|
||||
SeverityAlert Severity = "ALERT"
|
||||
SeverityCritical Severity = "CRITICAL"
|
||||
)
|
||||
|
||||
type TradingStatus string
|
||||
|
||||
const (
|
||||
TradingStatusNormal TradingStatus = "NORMAL_TRADING"
|
||||
TradingStatusClosed TradingStatus = "CLOSED"
|
||||
TradingStatusUnknown TradingStatus = "UNKNOWN"
|
||||
)
|
||||
|
||||
type Instrument struct {
|
||||
InstrumentUID string
|
||||
Figi string
|
||||
Ticker string
|
||||
ClassCode string
|
||||
Name string
|
||||
Lot int64
|
||||
MinPriceIncrement decimal.Decimal
|
||||
Currency string
|
||||
Enabled bool
|
||||
FundType string
|
||||
ExpectedCommissionBpsPerSide decimal.Decimal
|
||||
FreeOrderLimitPerDay int
|
||||
Quarantine bool
|
||||
QuarantineReason string
|
||||
ExcludeReason string
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (i Instrument) MetadataValid() bool {
|
||||
return i.InstrumentUID != "" &&
|
||||
!strings.HasPrefix(i.InstrumentUID, "PENDING:") &&
|
||||
i.Lot > 0 &&
|
||||
i.MinPriceIncrement.IsPositive() &&
|
||||
strings.EqualFold(i.Currency, "RUB")
|
||||
}
|
||||
|
||||
type Candle struct {
|
||||
InstrumentUID string
|
||||
TradeDate time.Time
|
||||
Open decimal.Decimal
|
||||
High decimal.Decimal
|
||||
Low decimal.Decimal
|
||||
Close decimal.Decimal
|
||||
VolumeLots decimal.Decimal
|
||||
Source string
|
||||
LoadedAt time.Time
|
||||
}
|
||||
|
||||
type FeatureSet struct {
|
||||
InstrumentUID string
|
||||
TradeDate time.Time
|
||||
ROn decimal.Decimal
|
||||
RDay decimal.Decimal
|
||||
MuOn60 decimal.Decimal
|
||||
MuOn252 decimal.Decimal
|
||||
SigmaOn60 decimal.Decimal
|
||||
TStatOn60 decimal.Decimal
|
||||
WinOn60 decimal.Decimal
|
||||
EWMAOn decimal.Decimal
|
||||
SpreadBps decimal.Decimal
|
||||
HalfSpreadBps decimal.Decimal
|
||||
TickBps decimal.Decimal
|
||||
ADV20 decimal.Decimal
|
||||
ExpectedCostBps decimal.Decimal
|
||||
NetEdgeBps decimal.Decimal
|
||||
EntryIntervalVolume decimal.Decimal
|
||||
ExitIntervalVolume decimal.Decimal
|
||||
CalculatedAt time.Time
|
||||
}
|
||||
|
||||
type Signal struct {
|
||||
ID int64
|
||||
TradeDate time.Time
|
||||
InstrumentUID string
|
||||
Decision SignalDecision
|
||||
Score decimal.Decimal
|
||||
NetEdgeBps decimal.Decimal
|
||||
TargetNotional decimal.Decimal
|
||||
TargetLots int64
|
||||
RejectReason string
|
||||
ContextJSON string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type Order struct {
|
||||
ClientOrderID string
|
||||
BrokerOrderID string
|
||||
AccountIDHash string
|
||||
InstrumentUID string
|
||||
TradeDate time.Time
|
||||
Side Side
|
||||
OrderType OrderType
|
||||
LimitPrice decimal.Decimal
|
||||
QuantityLots int64
|
||||
FilledLots int64
|
||||
AvgFillPrice decimal.Decimal
|
||||
Status OrderStatus
|
||||
Commission decimal.Decimal
|
||||
AttemptNo int
|
||||
RawStateJSON string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type Position struct {
|
||||
ID int64
|
||||
AccountIDHash string
|
||||
InstrumentUID string
|
||||
OpenTradeDate time.Time
|
||||
Lots int64
|
||||
Lot int64
|
||||
ExitFilledLots int64
|
||||
AvgBuyPrice decimal.Decimal
|
||||
AvgSellPrice decimal.Decimal
|
||||
Status PositionStatus
|
||||
GrossPnL decimal.Decimal
|
||||
NetPnL decimal.Decimal
|
||||
CommissionTotal decimal.Decimal
|
||||
RealizedEdgeBps decimal.Decimal
|
||||
OpenedAt *time.Time
|
||||
ClosedAt *time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type RiskEvent struct {
|
||||
ID int64
|
||||
TS time.Time
|
||||
Severity Severity
|
||||
EventType string
|
||||
InstrumentUID string
|
||||
Message string
|
||||
ContextJSON string
|
||||
}
|
||||
|
||||
type Holding struct {
|
||||
InstrumentUID string
|
||||
QuantityLots int64
|
||||
AveragePrice decimal.Decimal
|
||||
MarketValue decimal.Decimal
|
||||
}
|
||||
|
||||
type Portfolio struct {
|
||||
Equity decimal.Decimal
|
||||
Cash decimal.Decimal
|
||||
Holdings []Holding
|
||||
CheckedAt time.Time
|
||||
}
|
||||
|
||||
type OrderBookLevel struct {
|
||||
Price decimal.Decimal
|
||||
QuantityLots int64
|
||||
}
|
||||
|
||||
type OrderBook struct {
|
||||
InstrumentUID string
|
||||
Bids []OrderBookLevel
|
||||
Asks []OrderBookLevel
|
||||
Time time.Time
|
||||
ReceivedAt time.Time
|
||||
}
|
||||
|
||||
func (o OrderBook) BestBid() (decimal.Decimal, bool) {
|
||||
if len(o.Bids) == 0 {
|
||||
return decimal.Zero, false
|
||||
}
|
||||
return o.Bids[0].Price, true
|
||||
}
|
||||
|
||||
func (o OrderBook) BestAsk() (decimal.Decimal, bool) {
|
||||
if len(o.Asks) == 0 {
|
||||
return decimal.Zero, false
|
||||
}
|
||||
return o.Asks[0].Price, true
|
||||
}
|
||||
|
||||
type Operation struct {
|
||||
ID string
|
||||
InstrumentUID string
|
||||
Type string
|
||||
Payment decimal.Decimal
|
||||
Commission decimal.Decimal
|
||||
ExecutedAt time.Time
|
||||
}
|
||||
|
||||
type ReconciliationDiff struct {
|
||||
Kind string
|
||||
InstrumentUID string
|
||||
Message string
|
||||
Critical bool
|
||||
}
|
||||
@@ -0,0 +1,405 @@
|
||||
package execution
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/repository"
|
||||
)
|
||||
|
||||
var ErrBrokerOrdersDisabled = errors.New("broker orders are disabled for current mode")
|
||||
var ErrEmptyOrderBook = errors.New("order book has no usable bid/ask")
|
||||
|
||||
type Gateway interface {
|
||||
PostLimitOrder(ctx context.Context, accountID, instrumentUID string, side domain.Side, lots int64, price decimal.Decimal, clientOrderID string) (domain.Order, error)
|
||||
CancelOrder(ctx context.Context, accountID, orderID string) error
|
||||
GetOrderState(ctx context.Context, accountID, orderID string) (domain.Order, error)
|
||||
}
|
||||
|
||||
type Engine struct {
|
||||
mode domain.Mode
|
||||
accountID string
|
||||
gateway Gateway
|
||||
store repository.Repository
|
||||
maxQuoteAge time.Duration
|
||||
mu sync.Map
|
||||
}
|
||||
|
||||
type MonitorConfig struct {
|
||||
Deadline time.Time
|
||||
PollInterval time.Duration
|
||||
MaxAttempts int
|
||||
RepostAfter time.Duration
|
||||
Instrument domain.Instrument
|
||||
ImproveTicks int
|
||||
Quote func(ctx context.Context, instrumentUID string) (domain.OrderBook, error)
|
||||
}
|
||||
|
||||
func NewEngine(mode domain.Mode, accountID string, gateway Gateway, store repository.Repository) Engine {
|
||||
return Engine{mode: mode, accountID: accountID, gateway: gateway, store: store}
|
||||
}
|
||||
|
||||
func (e *Engine) SetMaxQuoteAge(maxQuoteAge time.Duration) {
|
||||
e.maxQuoteAge = maxQuoteAge
|
||||
}
|
||||
|
||||
func (e *Engine) PlaceEntry(ctx context.Context, accountIDHash string, instrument domain.Instrument, tradeDate time.Time, lots int64, book domain.OrderBook, improveTicks int, attempt int) (domain.Order, error) {
|
||||
if err := e.checkQuoteFresh(book); err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
bid, ask, err := bestBidAsk(book)
|
||||
if err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
price, err := LimitBuyPrice(bid, ask, instrument.MinPriceIncrement, improveTicks)
|
||||
if err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
return e.PlaceLimit(ctx, domain.Order{
|
||||
ClientOrderID: ClientOrderID(tradeDate, instrument.InstrumentUID, domain.SideBuy, attempt),
|
||||
AccountIDHash: accountIDHash,
|
||||
InstrumentUID: instrument.InstrumentUID,
|
||||
TradeDate: tradeDate,
|
||||
Side: domain.SideBuy,
|
||||
OrderType: domain.OrderTypeLimit,
|
||||
LimitPrice: price,
|
||||
QuantityLots: lots,
|
||||
Status: domain.OrderStatusNew,
|
||||
AttemptNo: attempt,
|
||||
RawStateJSON: "{}",
|
||||
})
|
||||
}
|
||||
|
||||
func (e *Engine) PlaceExit(ctx context.Context, accountIDHash string, instrument domain.Instrument, tradeDate time.Time, lots int64, book domain.OrderBook, improveTicks int, attempt int) (domain.Order, error) {
|
||||
if err := e.checkQuoteFresh(book); err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
bid, ask, err := bestBidAsk(book)
|
||||
if err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
price, err := LimitSellPrice(bid, ask, instrument.MinPriceIncrement, improveTicks)
|
||||
if err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
return e.PlaceLimit(ctx, domain.Order{
|
||||
ClientOrderID: ClientOrderID(tradeDate, instrument.InstrumentUID, domain.SideSell, attempt),
|
||||
AccountIDHash: accountIDHash,
|
||||
InstrumentUID: instrument.InstrumentUID,
|
||||
TradeDate: tradeDate,
|
||||
Side: domain.SideSell,
|
||||
OrderType: domain.OrderTypeLimit,
|
||||
LimitPrice: price,
|
||||
QuantityLots: lots,
|
||||
Status: domain.OrderStatusNew,
|
||||
AttemptNo: attempt,
|
||||
RawStateJSON: "{}",
|
||||
})
|
||||
}
|
||||
|
||||
func (e *Engine) PlaceLimit(ctx context.Context, order domain.Order) (domain.Order, error) {
|
||||
if e.store != nil {
|
||||
existing, err := e.findExisting(ctx, order)
|
||||
if err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
if existing.ClientOrderID != "" {
|
||||
return existing, nil
|
||||
}
|
||||
}
|
||||
if !e.mode.AllowsBrokerOrders() {
|
||||
order.Status = domain.OrderStatusNew
|
||||
if e.store != nil {
|
||||
return order, e.store.UpsertOrder(ctx, order)
|
||||
}
|
||||
return order, ErrBrokerOrdersDisabled
|
||||
}
|
||||
if e.gateway == nil {
|
||||
return domain.Order{}, errors.New("gateway is nil")
|
||||
}
|
||||
lock := e.lockFor(order.InstrumentUID)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
posted, err := e.gateway.PostLimitOrder(ctx, e.accountID, order.InstrumentUID, order.Side, order.QuantityLots, order.LimitPrice, order.ClientOrderID)
|
||||
if err != nil {
|
||||
order.Status = domain.OrderStatusFailed
|
||||
if e.store != nil {
|
||||
_ = e.store.UpsertOrder(ctx, order)
|
||||
}
|
||||
return domain.Order{}, err
|
||||
}
|
||||
posted.ClientOrderID = order.ClientOrderID
|
||||
posted.AccountIDHash = order.AccountIDHash
|
||||
posted.InstrumentUID = order.InstrumentUID
|
||||
posted.Side = order.Side
|
||||
posted.OrderType = order.OrderType
|
||||
posted.LimitPrice = order.LimitPrice
|
||||
posted.QuantityLots = order.QuantityLots
|
||||
posted.AttemptNo = order.AttemptNo
|
||||
posted.TradeDate = order.TradeDate
|
||||
posted.CreatedAt = time.Now().UTC()
|
||||
posted.UpdatedAt = posted.CreatedAt
|
||||
if e.store != nil {
|
||||
if err := e.store.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error {
|
||||
if err := repo.UpsertOrder(ctx, posted); err != nil {
|
||||
return fmt.Errorf("persist posted order: %w", err)
|
||||
}
|
||||
return repo.IncrementFreeOrders(ctx, order.TradeDate, order.InstrumentUID, 1)
|
||||
}); err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
}
|
||||
return posted, nil
|
||||
}
|
||||
|
||||
func (e *Engine) findExisting(ctx context.Context, order domain.Order) (domain.Order, error) {
|
||||
orders, err := e.store.ListOrders(ctx, order.AccountIDHash, order.TradeDate, order.TradeDate)
|
||||
if err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
for _, existing := range orders {
|
||||
if existing.ClientOrderID == order.ClientOrderID &&
|
||||
existing.Status != domain.OrderStatusFailed &&
|
||||
existing.Status != domain.OrderStatusRejected {
|
||||
return existing, nil
|
||||
}
|
||||
}
|
||||
return domain.Order{}, nil
|
||||
}
|
||||
|
||||
func (e *Engine) Refresh(ctx context.Context, order domain.Order) (domain.Order, error) {
|
||||
if e.gateway == nil {
|
||||
return domain.Order{}, errors.New("gateway is nil")
|
||||
}
|
||||
lock := e.lockFor(order.InstrumentUID)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
state, err := e.gateway.GetOrderState(ctx, e.accountID, order.BrokerOrderID)
|
||||
if err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
state.ClientOrderID = order.ClientOrderID
|
||||
state.AccountIDHash = order.AccountIDHash
|
||||
state.InstrumentUID = order.InstrumentUID
|
||||
state.TradeDate = order.TradeDate
|
||||
state.Side = order.Side
|
||||
state.OrderType = order.OrderType
|
||||
state.LimitPrice = order.LimitPrice
|
||||
state.QuantityLots = order.QuantityLots
|
||||
state.AttemptNo = order.AttemptNo
|
||||
if e.store != nil {
|
||||
if err := e.store.UpsertOrder(ctx, state); err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (e *Engine) Cancel(ctx context.Context, order domain.Order) error {
|
||||
if e.gateway == nil {
|
||||
return errors.New("gateway is nil")
|
||||
}
|
||||
lock := e.lockFor(order.InstrumentUID)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
if err := e.gateway.CancelOrder(ctx, e.accountID, order.BrokerOrderID); err != nil {
|
||||
return err
|
||||
}
|
||||
if e.store != nil {
|
||||
return e.store.UpdateOrderStatus(ctx, order.ClientOrderID, domain.OrderStatusCancelled, order.FilledLots, order.RawStateJSON)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) MonitorUntil(ctx context.Context, order domain.Order, cfg MonitorConfig) (domain.Order, error) {
|
||||
if cfg.PollInterval <= 0 {
|
||||
cfg.PollInterval = 500 * time.Millisecond
|
||||
}
|
||||
if cfg.MaxAttempts <= 0 {
|
||||
cfg.MaxAttempts = 1
|
||||
}
|
||||
lastPost := time.Now()
|
||||
current := order
|
||||
aggregate := order
|
||||
seen := map[string]domain.Order{order.ClientOrderID: order}
|
||||
ticker := time.NewTicker(cfg.PollInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
previous := seen[current.ClientOrderID]
|
||||
refreshed, err := e.Refresh(ctx, current)
|
||||
if err != nil {
|
||||
return aggregate, err
|
||||
}
|
||||
aggregate = mergeAggregateFill(aggregate, previous, refreshed)
|
||||
seen[current.ClientOrderID] = refreshed
|
||||
current = mergeOrderState(current, refreshed)
|
||||
aggregate.Status = current.Status
|
||||
aggregate.UpdatedAt = current.UpdatedAt
|
||||
aggregate.RawStateJSON = current.RawStateJSON
|
||||
if aggregate.FilledLots >= aggregate.QuantityLots {
|
||||
aggregate.Status = domain.OrderStatusFilled
|
||||
return aggregate, nil
|
||||
}
|
||||
if isTerminal(current.Status) {
|
||||
return aggregate, nil
|
||||
}
|
||||
if !cfg.Deadline.IsZero() && !time.Now().Before(cfg.Deadline) {
|
||||
if err := e.Cancel(ctx, current); err != nil {
|
||||
return aggregate, err
|
||||
}
|
||||
aggregate.Status = domain.OrderStatusExpired
|
||||
if e.store != nil {
|
||||
if err := e.store.UpdateOrderStatus(ctx, current.ClientOrderID, aggregate.Status, current.FilledLots, current.RawStateJSON); err != nil {
|
||||
return aggregate, err
|
||||
}
|
||||
}
|
||||
return aggregate, nil
|
||||
}
|
||||
shouldRepost := cfg.RepostAfter > 0 &&
|
||||
time.Since(lastPost) >= cfg.RepostAfter &&
|
||||
current.AttemptNo < cfg.MaxAttempts &&
|
||||
aggregate.FilledLots < aggregate.QuantityLots &&
|
||||
cfg.Quote != nil
|
||||
if shouldRepost {
|
||||
next, err := e.repost(ctx, current, cfg, aggregate.QuantityLots-aggregate.FilledLots)
|
||||
if err != nil {
|
||||
return aggregate, err
|
||||
}
|
||||
current = next
|
||||
seen[current.ClientOrderID] = current
|
||||
lastPost = time.Now()
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return aggregate, ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) repost(ctx context.Context, order domain.Order, cfg MonitorConfig, remaining int64) (domain.Order, error) {
|
||||
if err := e.Cancel(ctx, order); err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
if remaining <= 0 {
|
||||
order.Status = domain.OrderStatusFilled
|
||||
return order, nil
|
||||
}
|
||||
book, err := cfg.Quote(ctx, order.InstrumentUID)
|
||||
if err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
attempt := order.AttemptNo + 1
|
||||
switch order.Side {
|
||||
case domain.SideBuy:
|
||||
return e.PlaceEntry(ctx, order.AccountIDHash, cfg.Instrument, order.TradeDate, remaining, book, cfg.ImproveTicks, attempt)
|
||||
case domain.SideSell:
|
||||
return e.PlaceExit(ctx, order.AccountIDHash, cfg.Instrument, order.TradeDate, remaining, book, cfg.ImproveTicks, attempt)
|
||||
default:
|
||||
return domain.Order{}, fmt.Errorf("unsupported side %s", order.Side)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) checkQuoteFresh(book domain.OrderBook) error {
|
||||
if e.maxQuoteAge <= 0 {
|
||||
return nil
|
||||
}
|
||||
receivedAt := book.ReceivedAt
|
||||
if receivedAt.IsZero() {
|
||||
receivedAt = book.Time
|
||||
}
|
||||
if receivedAt.IsZero() {
|
||||
return fmt.Errorf("quote timestamp is missing")
|
||||
}
|
||||
age := time.Since(receivedAt)
|
||||
if age > e.maxQuoteAge {
|
||||
return fmt.Errorf("quote age %s exceeds %s", age, e.maxQuoteAge)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) lockFor(instrumentUID string) *sync.Mutex {
|
||||
value, _ := e.mu.LoadOrStore(instrumentUID, &sync.Mutex{})
|
||||
lock, ok := value.(*sync.Mutex)
|
||||
if !ok {
|
||||
panic("execution lock has unexpected type")
|
||||
}
|
||||
return lock
|
||||
}
|
||||
|
||||
func bestBidAsk(book domain.OrderBook) (decimal.Decimal, decimal.Decimal, error) {
|
||||
bid, ok := book.BestBid()
|
||||
if !ok {
|
||||
return decimal.Zero, decimal.Zero, ErrEmptyOrderBook
|
||||
}
|
||||
ask, ok := book.BestAsk()
|
||||
if !ok {
|
||||
return decimal.Zero, decimal.Zero, ErrEmptyOrderBook
|
||||
}
|
||||
return bid, ask, nil
|
||||
}
|
||||
|
||||
func isTerminal(status domain.OrderStatus) bool {
|
||||
switch status {
|
||||
case domain.OrderStatusFilled, domain.OrderStatusCancelled, domain.OrderStatusRejected, domain.OrderStatusExpired, domain.OrderStatusFailed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func mergeOrderState(base, state domain.Order) domain.Order {
|
||||
base.BrokerOrderID = state.BrokerOrderID
|
||||
base.FilledLots = state.FilledLots
|
||||
base.AvgFillPrice = state.AvgFillPrice
|
||||
base.Status = state.Status
|
||||
base.Commission = state.Commission
|
||||
base.RawStateJSON = state.RawStateJSON
|
||||
base.UpdatedAt = state.UpdatedAt
|
||||
return base
|
||||
}
|
||||
|
||||
func mergeAggregateFill(aggregate, previous, current domain.Order) domain.Order {
|
||||
deltaLots := current.FilledLots - previous.FilledLots
|
||||
if deltaLots > 0 {
|
||||
deltaAvg := fillDeltaAvg(previous, current, deltaLots)
|
||||
previousValue := aggregate.AvgFillPrice.Mul(decimal.NewFromInt(aggregate.FilledLots))
|
||||
deltaValue := deltaAvg.Mul(decimal.NewFromInt(deltaLots))
|
||||
aggregate.FilledLots += deltaLots
|
||||
aggregate.AvgFillPrice = previousValue.Add(deltaValue).Div(decimal.NewFromInt(aggregate.FilledLots))
|
||||
}
|
||||
deltaCommission := current.Commission.Sub(previous.Commission)
|
||||
if deltaCommission.IsPositive() {
|
||||
aggregate.Commission = aggregate.Commission.Add(deltaCommission)
|
||||
}
|
||||
return aggregate
|
||||
}
|
||||
|
||||
func fillDeltaAvg(previous, current domain.Order, deltaLots int64) decimal.Decimal {
|
||||
if deltaLots <= 0 {
|
||||
return decimal.Zero
|
||||
}
|
||||
if previous.FilledLots <= 0 {
|
||||
if current.AvgFillPrice.IsPositive() {
|
||||
return current.AvgFillPrice
|
||||
}
|
||||
return current.LimitPrice
|
||||
}
|
||||
currentValue := current.AvgFillPrice.Mul(decimal.NewFromInt(current.FilledLots))
|
||||
previousValue := previous.AvgFillPrice.Mul(decimal.NewFromInt(previous.FilledLots))
|
||||
if currentValue.GreaterThan(previousValue) {
|
||||
return currentValue.Sub(previousValue).Div(decimal.NewFromInt(deltaLots))
|
||||
}
|
||||
if current.AvgFillPrice.IsPositive() {
|
||||
return current.AvgFillPrice
|
||||
}
|
||||
return current.LimitPrice
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package execution
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/money"
|
||||
)
|
||||
|
||||
var nonIDChar = regexp.MustCompile(`[^A-Za-z0-9_-]+`)
|
||||
|
||||
func LimitBuyPrice(bestBid, bestAsk, tick decimal.Decimal, improveTicks int) (decimal.Decimal, error) {
|
||||
if improveTicks < 0 {
|
||||
improveTicks = 0
|
||||
}
|
||||
if !tick.IsPositive() {
|
||||
return decimal.Zero, money.ErrInvalidTick
|
||||
}
|
||||
candidate := bestBid.Add(tick.Mul(decimal.NewFromInt(int64(improveTicks))))
|
||||
upper := bestAsk.Sub(tick)
|
||||
if candidate.LessThanOrEqual(upper) {
|
||||
return money.RoundToTick(candidate, tick, money.RoundFloor)
|
||||
}
|
||||
return money.RoundToTick(bestBid, tick, money.RoundFloor)
|
||||
}
|
||||
|
||||
func LimitSellPrice(bestBid, bestAsk, tick decimal.Decimal, improveTicks int) (decimal.Decimal, error) {
|
||||
if improveTicks < 0 {
|
||||
improveTicks = 0
|
||||
}
|
||||
if !tick.IsPositive() {
|
||||
return decimal.Zero, money.ErrInvalidTick
|
||||
}
|
||||
candidate := bestAsk.Sub(tick.Mul(decimal.NewFromInt(int64(improveTicks))))
|
||||
lower := bestBid.Add(tick)
|
||||
if candidate.GreaterThanOrEqual(lower) {
|
||||
return money.RoundToTick(candidate, tick, money.RoundCeil)
|
||||
}
|
||||
return money.RoundToTick(bestAsk, tick, money.RoundCeil)
|
||||
}
|
||||
|
||||
func ClientOrderID(tradeDate time.Time, instrumentUID string, side domain.Side, attempt int) string {
|
||||
base := fmt.Sprintf("%s|%s|%s|%d", tradeDate.Format("20060102"), instrumentUID, side, attempt)
|
||||
sum := sha256.Sum256([]byte(base))
|
||||
suffix := hex.EncodeToString(sum[:])[:8]
|
||||
cleanUID := nonIDChar.ReplaceAllString(instrumentUID, "_")
|
||||
if len(cleanUID) > 24 {
|
||||
cleanUID = cleanUID[:24]
|
||||
}
|
||||
return strings.ToLower(fmt.Sprintf("otb-%s-%s-%s-%02d-%s", tradeDate.Format("20060102"), cleanUID, side, attempt, suffix))
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package execution
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
func ed(raw string) decimal.Decimal {
|
||||
v, err := decimal.NewFromString(raw)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func TestLimitPricesDoNotCross(t *testing.T) {
|
||||
buy, err := LimitBuyPrice(ed("100"), ed("100.03"), ed("0.01"), 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !buy.Equal(ed("100.01")) {
|
||||
t.Fatalf("buy=%s", buy)
|
||||
}
|
||||
sell, err := LimitSellPrice(ed("100"), ed("100.03"), ed("0.01"), 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !sell.Equal(ed("100.02")) {
|
||||
t.Fatalf("sell=%s", sell)
|
||||
}
|
||||
tightBuy, _ := LimitBuyPrice(ed("100"), ed("100.01"), ed("0.01"), 1)
|
||||
if !tightBuy.Equal(ed("100")) {
|
||||
t.Fatalf("tight buy=%s", tightBuy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientOrderIDDeterministic(t *testing.T) {
|
||||
date := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
|
||||
a := ClientOrderID(date, "uid", domain.SideBuy, 1)
|
||||
b := ClientOrderID(date, "uid", domain.SideBuy, 1)
|
||||
c := ClientOrderID(date, "uid", domain.SideBuy, 2)
|
||||
if a != b || a == c {
|
||||
t.Fatalf("unexpected ids: %s %s %s", a, b, c)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package execution
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/testutil"
|
||||
"overnight-trading-bot/internal/tinvest"
|
||||
)
|
||||
|
||||
func TestClientOrderIDIncludesAttempt(t *testing.T) {
|
||||
date := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
|
||||
first := ClientOrderID(date, "uid:TRUR", domain.SideBuy, 1)
|
||||
second := ClientOrderID(date, "uid:TRUR", domain.SideBuy, 1)
|
||||
third := ClientOrderID(date, "uid:TRUR", domain.SideBuy, 2)
|
||||
if first != second {
|
||||
t.Fatalf("client order id is not deterministic: %s != %s", first, second)
|
||||
}
|
||||
if first == third {
|
||||
t.Fatalf("attempt is not part of client order id: %s", first)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlaceLimitSuppressesDuplicateSubmit(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
gateway := tinvest.NewFakeGateway()
|
||||
engine := NewEngine(domain.ModeSandbox, "account", gateway, repo)
|
||||
tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
|
||||
order := domain.Order{
|
||||
ClientOrderID: "order-1",
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid",
|
||||
TradeDate: tradeDate,
|
||||
Side: domain.SideBuy,
|
||||
OrderType: domain.OrderTypeLimit,
|
||||
LimitPrice: decimal.NewFromInt(100),
|
||||
QuantityLots: 1,
|
||||
Status: domain.OrderStatusNew,
|
||||
AttemptNo: 1,
|
||||
}
|
||||
first, err := engine.PlaceLimit(ctx, order)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
second, err := engine.PlaceLimit(ctx, order)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if first.BrokerOrderID != second.BrokerOrderID {
|
||||
t.Fatalf("duplicate submit posted a new broker order: %s != %s", first.BrokerOrderID, second.BrokerOrderID)
|
||||
}
|
||||
if got := len(gateway.Orders); got != 1 {
|
||||
t.Fatalf("broker posts=%d, want 1", got)
|
||||
}
|
||||
sent, err := repo.GetFreeOrdersSent(ctx, tradeDate, "uid")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sent != 1 {
|
||||
t.Fatalf("free order counter=%d, want 1", sent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlaceEntryRejectsStaleQuote(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
engine := NewEngine(domain.ModeSandbox, "account", tinvest.NewFakeGateway(), testutil.NewMemoryRepository())
|
||||
engine.SetMaxQuoteAge(time.Second)
|
||||
_, err := engine.PlaceEntry(ctx, "hash", domain.Instrument{
|
||||
InstrumentUID: "uid",
|
||||
Lot: 1,
|
||||
MinPriceIncrement: decimal.NewFromInt(1),
|
||||
}, time.Now().UTC(), 1, domain.OrderBook{
|
||||
InstrumentUID: "uid",
|
||||
Bids: []domain.OrderBookLevel{{Price: decimal.NewFromInt(99), QuantityLots: 10}},
|
||||
Asks: []domain.OrderBookLevel{{Price: decimal.NewFromInt(101), QuantityLots: 10}},
|
||||
ReceivedAt: time.Now().UTC().Add(-2 * time.Second),
|
||||
}, 1, 1)
|
||||
if err == nil {
|
||||
t.Fatal("expected stale quote error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMonitorUntilRepostsAndExpiresAtDeadline(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
gateway := tinvest.NewFakeGateway()
|
||||
engine := NewEngine(domain.ModeSandbox, "account", gateway, repo)
|
||||
instrument := domain.Instrument{
|
||||
InstrumentUID: "uid",
|
||||
Lot: 1,
|
||||
MinPriceIncrement: decimal.NewFromInt(1),
|
||||
}
|
||||
book := domain.OrderBook{
|
||||
InstrumentUID: "uid",
|
||||
Bids: []domain.OrderBookLevel{{Price: decimal.NewFromInt(99), QuantityLots: 10}},
|
||||
Asks: []domain.OrderBookLevel{{Price: decimal.NewFromInt(101), QuantityLots: 10}},
|
||||
ReceivedAt: time.Now().UTC(),
|
||||
}
|
||||
tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
|
||||
order, err := engine.PlaceEntry(ctx, "hash", instrument, tradeDate, 3, book, 1, 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
monitored, err := engine.MonitorUntil(ctx, order, MonitorConfig{
|
||||
Deadline: time.Now().Add(20 * time.Millisecond),
|
||||
PollInterval: time.Millisecond,
|
||||
MaxAttempts: 2,
|
||||
RepostAfter: time.Nanosecond,
|
||||
Instrument: instrument,
|
||||
ImproveTicks: 1,
|
||||
Quote: func(context.Context, string) (domain.OrderBook, error) {
|
||||
book.ReceivedAt = time.Now().UTC()
|
||||
return book, nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if monitored.Status != domain.OrderStatusExpired {
|
||||
t.Fatalf("status=%s, want EXPIRED", monitored.Status)
|
||||
}
|
||||
if got := len(gateway.Orders); got < 2 {
|
||||
t.Fatalf("broker orders=%d, want repost attempt", got)
|
||||
}
|
||||
sent, err := repo.GetFreeOrdersSent(ctx, tradeDate, "uid")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sent != 2 {
|
||||
t.Fatalf("free order counter=%d, want 2", sent)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
package features
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/repository"
|
||||
"overnight-trading-bot/internal/timeutil"
|
||||
)
|
||||
|
||||
type PipelineConfig struct {
|
||||
RollingShort int
|
||||
RollingLong int
|
||||
EWMALambda float64
|
||||
RiskBufferBps decimal.Decimal
|
||||
EntrySlippageBps decimal.Decimal
|
||||
ExitSlippageBps decimal.Decimal
|
||||
CommissionRoundtripBps decimal.Decimal
|
||||
EntryWindow timeutil.Window
|
||||
ExitWindow timeutil.Window
|
||||
Location *time.Location
|
||||
}
|
||||
|
||||
type Pipeline struct {
|
||||
repo repository.Repository
|
||||
cfg PipelineConfig
|
||||
}
|
||||
|
||||
func NewPipeline(repo repository.Repository, cfg PipelineConfig) Pipeline {
|
||||
return Pipeline{repo: repo, cfg: cfg}
|
||||
}
|
||||
|
||||
func (p Pipeline) Recompute(ctx context.Context, instrument domain.Instrument, tradeDate time.Time, spread SpreadResult) (domain.FeatureSet, error) {
|
||||
from := tradeDate.AddDate(0, 0, -p.cfg.RollingLong-5)
|
||||
candles, err := p.repo.ListDailyCandles(ctx, instrument.InstrumentUID, from, tradeDate)
|
||||
if err != nil {
|
||||
return domain.FeatureSet{}, err
|
||||
}
|
||||
entryVolume, err := p.intervalVolume(ctx, instrument, tradeDate, p.cfg.EntryWindow)
|
||||
if err != nil {
|
||||
return domain.FeatureSet{}, err
|
||||
}
|
||||
exitVolume, err := p.intervalVolume(ctx, instrument, tradeDate.AddDate(0, 0, 1), p.cfg.ExitWindow)
|
||||
if err != nil {
|
||||
return domain.FeatureSet{}, err
|
||||
}
|
||||
feature, err := Compute(instrument, candles, tradeDate, spread, p.cfg, entryVolume, exitVolume)
|
||||
if err != nil {
|
||||
return domain.FeatureSet{}, err
|
||||
}
|
||||
if err := p.repo.UpsertFeature(ctx, feature); err != nil {
|
||||
return domain.FeatureSet{}, err
|
||||
}
|
||||
return feature, nil
|
||||
}
|
||||
|
||||
func (p Pipeline) intervalVolume(ctx context.Context, instrument domain.Instrument, date time.Time, window timeutil.Window) (decimal.Decimal, error) {
|
||||
if window.Start.Duration == 0 && window.End.Duration == 0 {
|
||||
return decimal.Zero, nil
|
||||
}
|
||||
loc := p.cfg.Location
|
||||
if loc == nil {
|
||||
loc = time.UTC
|
||||
}
|
||||
from := window.Start.On(date, loc).UTC()
|
||||
to := window.End.On(date, loc).UTC()
|
||||
candles, err := p.repo.ListMinuteCandles(ctx, instrument.InstrumentUID, from, to)
|
||||
if err != nil {
|
||||
return decimal.Zero, err
|
||||
}
|
||||
return IntervalVolume(candles, instrument.Lot), nil
|
||||
}
|
||||
|
||||
func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate time.Time, spread SpreadResult, cfg PipelineConfig, entryVolume, exitVolume decimal.Decimal) (domain.FeatureSet, error) {
|
||||
if len(candles) < 2 {
|
||||
return domain.FeatureSet{}, fmt.Errorf("need at least 2 candles, got %d", len(candles))
|
||||
}
|
||||
var overnight []float64
|
||||
var lastROn decimal.Decimal
|
||||
var lastRDay decimal.Decimal
|
||||
for i := 1; i < len(candles); i++ {
|
||||
rOn, err := OvernightReturn(candles[i].Open, candles[i-1].Close)
|
||||
if err != nil {
|
||||
return domain.FeatureSet{}, err
|
||||
}
|
||||
rDay, err := IntradayReturn(candles[i].Close, candles[i].Open)
|
||||
if err != nil {
|
||||
return domain.FeatureSet{}, err
|
||||
}
|
||||
onFloat, _ := rOn.Float64()
|
||||
overnight = append(overnight, onFloat)
|
||||
lastROn = rOn
|
||||
lastRDay = rDay
|
||||
}
|
||||
short := Rolling(overnight, cfg.RollingShort, cfg.EWMALambda)
|
||||
long := Rolling(overnight, cfg.RollingLong, cfg.EWMALambda)
|
||||
adv := ADV(candles, instrument.Lot, 20)
|
||||
rawEdgeBps := decimal.NewFromFloat(short.Mean).Mul(decimal.NewFromInt(10_000))
|
||||
if !entryVolume.IsPositive() {
|
||||
entryVolume = adv
|
||||
}
|
||||
if !exitVolume.IsPositive() {
|
||||
exitVolume = adv
|
||||
}
|
||||
instrumentCommission := instrument.ExpectedCommissionBpsPerSide.Mul(decimal.NewFromInt(2))
|
||||
expectedCost := spread.SpreadBps.
|
||||
Add(cfg.EntrySlippageBps).
|
||||
Add(cfg.ExitSlippageBps).
|
||||
Add(cfg.CommissionRoundtripBps).
|
||||
Add(instrumentCommission).
|
||||
Add(cfg.RiskBufferBps)
|
||||
return domain.FeatureSet{
|
||||
InstrumentUID: instrument.InstrumentUID,
|
||||
TradeDate: tradeDate,
|
||||
ROn: lastROn,
|
||||
RDay: lastRDay,
|
||||
MuOn60: decimal.NewFromFloat(short.Mean),
|
||||
MuOn252: decimal.NewFromFloat(long.Mean),
|
||||
SigmaOn60: decimal.NewFromFloat(short.StdDev),
|
||||
TStatOn60: decimal.NewFromFloat(short.TStat),
|
||||
WinOn60: decimal.NewFromFloat(short.WinRate),
|
||||
EWMAOn: decimal.NewFromFloat(short.EWMA),
|
||||
SpreadBps: spread.SpreadBps,
|
||||
HalfSpreadBps: spread.HalfSpreadBps,
|
||||
TickBps: spread.TickBps,
|
||||
ADV20: adv,
|
||||
ExpectedCostBps: expectedCost,
|
||||
NetEdgeBps: rawEdgeBps.Sub(expectedCost),
|
||||
EntryIntervalVolume: entryVolume,
|
||||
ExitIntervalVolume: exitVolume,
|
||||
CalculatedAt: time.Now().UTC(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func IntervalVolume(candles []domain.Candle, lot int64) decimal.Decimal {
|
||||
if lot <= 0 {
|
||||
return decimal.Zero
|
||||
}
|
||||
total := decimal.Zero
|
||||
for _, candle := range candles {
|
||||
total = total.Add(candle.VolumeLots.Mul(decimal.NewFromInt(lot)).Mul(candle.Close))
|
||||
}
|
||||
return total
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package features
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
func TestComputeExpectedCostIncludesCommissionAndSlippage(t *testing.T) {
|
||||
var candles []domain.Candle
|
||||
start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
for i := 0; i < 6; i++ {
|
||||
price := decimal.NewFromInt(int64(100 + i))
|
||||
candles = append(candles, domain.Candle{
|
||||
InstrumentUID: "uid",
|
||||
TradeDate: start.AddDate(0, 0, i),
|
||||
Open: price,
|
||||
Close: price,
|
||||
VolumeLots: decimal.NewFromInt(1000),
|
||||
})
|
||||
}
|
||||
got, err := Compute(domain.Instrument{
|
||||
InstrumentUID: "uid",
|
||||
Lot: 1,
|
||||
ExpectedCommissionBpsPerSide: decimal.NewFromInt(1),
|
||||
}, candles, start.AddDate(0, 0, 5), SpreadResult{SpreadBps: decimal.NewFromInt(10)}, PipelineConfig{
|
||||
RollingShort: 2,
|
||||
RollingLong: 2,
|
||||
EWMALambda: 0.08,
|
||||
RiskBufferBps: decimal.NewFromInt(5),
|
||||
EntrySlippageBps: decimal.NewFromInt(2),
|
||||
ExitSlippageBps: decimal.NewFromInt(3),
|
||||
CommissionRoundtripBps: decimal.NewFromInt(4),
|
||||
}, decimal.NewFromInt(10000), decimal.NewFromInt(9000))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !got.ExpectedCostBps.Equal(decimal.NewFromInt(26)) {
|
||||
t.Fatalf("expected cost=%s, want 26", got.ExpectedCostBps)
|
||||
}
|
||||
if !got.EntryIntervalVolume.Equal(decimal.NewFromInt(10000)) || !got.ExitIntervalVolume.Equal(decimal.NewFromInt(9000)) {
|
||||
t.Fatalf("interval volumes were not preserved: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntervalVolume(t *testing.T) {
|
||||
got := IntervalVolume([]domain.Candle{
|
||||
{Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(10)},
|
||||
{Close: decimal.NewFromInt(101), VolumeLots: decimal.NewFromInt(20)},
|
||||
}, 2)
|
||||
if !got.Equal(decimal.NewFromInt(6040)) {
|
||||
t.Fatalf("interval volume=%s, want 6040", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,207 @@
|
||||
package features
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"sort"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/money"
|
||||
)
|
||||
|
||||
var ErrInvalidPrice = errors.New("price must be positive")
|
||||
|
||||
func OvernightReturn(open, previousClose decimal.Decimal) (decimal.Decimal, error) {
|
||||
if !open.IsPositive() || !previousClose.IsPositive() {
|
||||
return decimal.Zero, ErrInvalidPrice
|
||||
}
|
||||
return open.Div(previousClose).Sub(decimal.NewFromInt(1)), nil
|
||||
}
|
||||
|
||||
func IntradayReturn(close, open decimal.Decimal) (decimal.Decimal, error) {
|
||||
if !close.IsPositive() || !open.IsPositive() {
|
||||
return decimal.Zero, ErrInvalidPrice
|
||||
}
|
||||
return close.Div(open).Sub(decimal.NewFromInt(1)), nil
|
||||
}
|
||||
|
||||
func LogReturn(to, from decimal.Decimal) (float64, error) {
|
||||
if !to.IsPositive() || !from.IsPositive() {
|
||||
return 0, ErrInvalidPrice
|
||||
}
|
||||
ratio, _ := to.Div(from).Float64()
|
||||
return math.Log(ratio), nil
|
||||
}
|
||||
|
||||
func CumulativeLinear(returns []decimal.Decimal) decimal.Decimal {
|
||||
total := decimal.NewFromInt(1)
|
||||
for _, r := range returns {
|
||||
total = total.Mul(decimal.NewFromInt(1).Add(r))
|
||||
}
|
||||
return total.Sub(decimal.NewFromInt(1))
|
||||
}
|
||||
|
||||
func CumulativeLog(logReturns []float64) float64 {
|
||||
sum := 0.0
|
||||
for _, r := range logReturns {
|
||||
sum += r
|
||||
}
|
||||
return math.Exp(sum) - 1
|
||||
}
|
||||
|
||||
type RollingResult struct {
|
||||
Mean float64
|
||||
StdDev float64
|
||||
TStat float64
|
||||
WinRate float64
|
||||
EWMA float64
|
||||
Available bool
|
||||
}
|
||||
|
||||
func Rolling(values []float64, window int, lambda float64) RollingResult {
|
||||
if window <= 0 || len(values) < window {
|
||||
return RollingResult{}
|
||||
}
|
||||
sample := values[len(values)-window:]
|
||||
mean := Mean(sample)
|
||||
std := StdDev(sample)
|
||||
win := WinRate(sample)
|
||||
ewma := EWMA(values, lambda)
|
||||
res := RollingResult{
|
||||
Mean: mean,
|
||||
StdDev: std,
|
||||
WinRate: win,
|
||||
EWMA: ewma,
|
||||
Available: true,
|
||||
}
|
||||
if std > 0 {
|
||||
res.TStat = mean / std * math.Sqrt(float64(window))
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func Mean(values []float64) float64 {
|
||||
if len(values) == 0 {
|
||||
return 0
|
||||
}
|
||||
sum := 0.0
|
||||
for _, value := range values {
|
||||
sum += value
|
||||
}
|
||||
return sum / float64(len(values))
|
||||
}
|
||||
|
||||
func StdDev(values []float64) float64 {
|
||||
if len(values) < 2 {
|
||||
return 0
|
||||
}
|
||||
mean := Mean(values)
|
||||
sum := 0.0
|
||||
for _, value := range values {
|
||||
diff := value - mean
|
||||
sum += diff * diff
|
||||
}
|
||||
return math.Sqrt(sum / float64(len(values)-1))
|
||||
}
|
||||
|
||||
func WinRate(values []float64) float64 {
|
||||
if len(values) == 0 {
|
||||
return 0
|
||||
}
|
||||
wins := 0
|
||||
for _, value := range values {
|
||||
if value > 0 {
|
||||
wins++
|
||||
}
|
||||
}
|
||||
return float64(wins) / float64(len(values))
|
||||
}
|
||||
|
||||
func EWMA(values []float64, lambda float64) float64 {
|
||||
if len(values) == 0 {
|
||||
return 0
|
||||
}
|
||||
if lambda <= 0 || lambda > 1 {
|
||||
lambda = 0.08
|
||||
}
|
||||
ewma := values[0]
|
||||
for _, value := range values[1:] {
|
||||
ewma = lambda*value + (1-lambda)*ewma
|
||||
}
|
||||
return ewma
|
||||
}
|
||||
|
||||
type SpreadResult struct {
|
||||
SpreadAbs decimal.Decimal
|
||||
SpreadBps decimal.Decimal
|
||||
HalfSpreadBps decimal.Decimal
|
||||
TickBps decimal.Decimal
|
||||
Mid decimal.Decimal
|
||||
}
|
||||
|
||||
func Spread(bestBid, bestAsk, tick decimal.Decimal) (SpreadResult, error) {
|
||||
if !bestBid.IsPositive() || !bestAsk.IsPositive() || bestAsk.LessThanOrEqual(bestBid) {
|
||||
return SpreadResult{}, ErrInvalidPrice
|
||||
}
|
||||
mid := bestAsk.Add(bestBid).Div(decimal.NewFromInt(2))
|
||||
spreadAbs := bestAsk.Sub(bestBid)
|
||||
spreadBps, err := money.Bps(spreadAbs, mid)
|
||||
if err != nil {
|
||||
return SpreadResult{}, err
|
||||
}
|
||||
tickBps := decimal.Zero
|
||||
if tick.IsPositive() {
|
||||
tickBps, err = money.Bps(tick, mid)
|
||||
if err != nil {
|
||||
return SpreadResult{}, err
|
||||
}
|
||||
}
|
||||
return SpreadResult{
|
||||
SpreadAbs: spreadAbs,
|
||||
SpreadBps: spreadBps,
|
||||
HalfSpreadBps: spreadBps.Div(decimal.NewFromInt(2)),
|
||||
TickBps: tickBps,
|
||||
Mid: mid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ADV(candles []domain.Candle, lot int64, window int) decimal.Decimal {
|
||||
if lot <= 0 || window <= 0 || len(candles) == 0 {
|
||||
return decimal.Zero
|
||||
}
|
||||
sort.Slice(candles, func(i, j int) bool {
|
||||
return candles[i].TradeDate.Before(candles[j].TradeDate)
|
||||
})
|
||||
if len(candles) > window {
|
||||
candles = candles[len(candles)-window:]
|
||||
}
|
||||
total := decimal.Zero
|
||||
for _, candle := range candles {
|
||||
total = total.Add(candle.VolumeLots.Mul(decimal.NewFromInt(lot)).Mul(candle.Close))
|
||||
}
|
||||
return total.Div(decimal.NewFromInt(int64(len(candles))))
|
||||
}
|
||||
|
||||
func Quantile(values []float64, q float64) float64 {
|
||||
if len(values) == 0 {
|
||||
return 0
|
||||
}
|
||||
cp := append([]float64(nil), values...)
|
||||
sort.Float64s(cp)
|
||||
if q <= 0 {
|
||||
return cp[0]
|
||||
}
|
||||
if q >= 1 {
|
||||
return cp[len(cp)-1]
|
||||
}
|
||||
pos := q * float64(len(cp)-1)
|
||||
lower := int(math.Floor(pos))
|
||||
upper := int(math.Ceil(pos))
|
||||
if lower == upper {
|
||||
return cp[lower]
|
||||
}
|
||||
weight := pos - float64(lower)
|
||||
return cp[lower]*(1-weight) + cp[upper]*weight
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package features
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
func dec(raw string) decimal.Decimal {
|
||||
v, err := decimal.NewFromString(raw)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func TestReturnsAndLogIdentity(t *testing.T) {
|
||||
rOn, err := OvernightReturn(dec("102"), dec("100"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !rOn.Equal(dec("0.02")) {
|
||||
t.Fatalf("overnight return=%s", rOn)
|
||||
}
|
||||
rDay, err := IntradayReturn(dec("105"), dec("102"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !rDay.Round(10).Equal(dec("0.0294117647")) {
|
||||
t.Fatalf("intraday return=%s", rDay)
|
||||
}
|
||||
linear := CumulativeLinear([]decimal.Decimal{dec("0.01"), dec("-0.02"), dec("0.03")})
|
||||
logs := []float64{math.Log(1.01), math.Log(0.98), math.Log(1.03)}
|
||||
if math.Abs(linear.InexactFloat64()-CumulativeLog(logs)) > 1e-10 {
|
||||
t.Fatalf("linear/log cumulative mismatch")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package features
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRollingStats(t *testing.T) {
|
||||
values := []float64{0.01, -0.01, 0.02, 0.03}
|
||||
got := Rolling(values, 4, 0.5)
|
||||
if !got.Available {
|
||||
t.Fatal("expected rolling result")
|
||||
}
|
||||
if math.Abs(got.Mean-0.0125) > 1e-12 {
|
||||
t.Fatalf("mean=%f", got.Mean)
|
||||
}
|
||||
if math.Abs(got.WinRate-0.75) > 1e-12 {
|
||||
t.Fatalf("win=%f", got.WinRate)
|
||||
}
|
||||
if got.StdDev <= 0 || got.TStat <= 0 {
|
||||
t.Fatalf("std/tstat invalid: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRollingSigmaZero(t *testing.T) {
|
||||
got := Rolling([]float64{0.01, 0.01, 0.01}, 3, 0.08)
|
||||
if got.StdDev != 0 || got.TStat != 0 {
|
||||
t.Fatalf("expected zero sigma/tstat, got %+v", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package features
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSpread(t *testing.T) {
|
||||
got, err := Spread(dec("99"), dec("101"), dec("0.1"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !got.Mid.Equal(dec("100")) || !got.SpreadBps.Equal(dec("200")) || !got.HalfSpreadBps.Equal(dec("100")) || !got.TickBps.Equal(dec("10")) {
|
||||
t.Fatalf("unexpected spread: %+v", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"overnight-trading-bot/internal/timeutil"
|
||||
"overnight-trading-bot/internal/tinvest"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
db *sql.DB
|
||||
gateway tinvest.Gateway
|
||||
maxDrift time.Duration
|
||||
server *http.Server
|
||||
}
|
||||
|
||||
func New(db *sql.DB, gateway tinvest.Gateway, maxDrift time.Duration) *Service {
|
||||
return &Service{db: db, gateway: gateway, maxDrift: maxDrift}
|
||||
}
|
||||
|
||||
func (s *Service) Start(addr string) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/health", s.handleHealth)
|
||||
mux.HandleFunc("/ready", s.handleReady)
|
||||
s.server = &http.Server{Addr: addr, Handler: mux, ReadHeaderTimeout: 3 * time.Second}
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
// HTTP health errors are intentionally surfaced through /ready and logs by caller.
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) Shutdown(ctx context.Context) error {
|
||||
if s.server == nil {
|
||||
return nil
|
||||
}
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func (s *Service) Check(ctx context.Context) map[string]string {
|
||||
status := map[string]string{"status": "ok"}
|
||||
if s.db != nil {
|
||||
if err := s.db.PingContext(ctx); err != nil {
|
||||
status["status"] = "fail"
|
||||
status["db"] = err.Error()
|
||||
} else {
|
||||
status["db"] = "ok"
|
||||
}
|
||||
}
|
||||
if s.gateway != nil {
|
||||
serverTime, err := s.gateway.GetServerTime(ctx)
|
||||
if err != nil {
|
||||
status["status"] = "fail"
|
||||
status["api"] = err.Error()
|
||||
} else {
|
||||
status["api"] = "ok"
|
||||
drift := timeutil.Drift(time.Now().UTC(), serverTime)
|
||||
status["clock_drift"] = drift.String()
|
||||
if s.maxDrift > 0 && drift > s.maxDrift {
|
||||
status["status"] = "fail"
|
||||
status["clock"] = fmt.Sprintf("drift %s exceeds %s", drift, s.maxDrift)
|
||||
}
|
||||
}
|
||||
}
|
||||
return status
|
||||
}
|
||||
|
||||
func (s *Service) handleHealth(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
func (s *Service) handleReady(w http.ResponseWriter, r *http.Request) {
|
||||
status := s.Check(r.Context())
|
||||
code := http.StatusOK
|
||||
if status["status"] != "ok" {
|
||||
code = http.StatusServiceUnavailable
|
||||
}
|
||||
writeJSON(w, code, status)
|
||||
}
|
||||
|
||||
func CheckEndpoint(ctx context.Context, url string) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
if resp.StatusCode >= 300 {
|
||||
return fmt.Errorf("healthcheck returned %s", resp.Status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, code int, value any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
_ = json.NewEncoder(w).Encode(value)
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package instruments
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/repository"
|
||||
"overnight-trading-bot/internal/tinvest"
|
||||
)
|
||||
|
||||
type Registry struct {
|
||||
repo repository.Repository
|
||||
gateway tinvest.Gateway
|
||||
}
|
||||
|
||||
func NewRegistry(repo repository.Repository, gateway tinvest.Gateway) Registry {
|
||||
return Registry{repo: repo, gateway: gateway}
|
||||
}
|
||||
|
||||
func (r Registry) SyncMetadata(ctx context.Context) error {
|
||||
instruments, err := r.repo.ListInstruments(ctx, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, instrument := range instruments {
|
||||
if strings.HasPrefix(instrument.InstrumentUID, "PENDING:") || !instrument.MetadataValid() {
|
||||
remote, err := r.gateway.GetInstrument(ctx, instrument.Ticker, instrument.ClassCode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sync %s: %w", instrument.Ticker, err)
|
||||
}
|
||||
remote.Enabled = instrument.Enabled && remote.Enabled
|
||||
remote.FundType = instrument.FundType
|
||||
remote.ExpectedCommissionBpsPerSide = instrument.ExpectedCommissionBpsPerSide
|
||||
remote.FreeOrderLimitPerDay = instrument.FreeOrderLimitPerDay
|
||||
remote.Quarantine = instrument.Quarantine
|
||||
remote.QuarantineReason = instrument.QuarantineReason
|
||||
remote.ExcludeReason = instrument.ExcludeReason
|
||||
if err := r.repo.ReplaceInstrument(ctx, instrument.InstrumentUID, remote); err != nil {
|
||||
return fmt.Errorf("replace synced instrument %s: %w", instrument.Ticker, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CheckInstrument(instrument domain.Instrument, status domain.TradingStatus) error {
|
||||
switch {
|
||||
case !instrument.Enabled:
|
||||
return fmt.Errorf("%s disabled", instrument.Ticker)
|
||||
case instrument.Quarantine:
|
||||
return fmt.Errorf("%s quarantined: %s", instrument.Ticker, instrument.QuarantineReason)
|
||||
case !instrument.MetadataValid():
|
||||
return fmt.Errorf("%s invalid metadata", instrument.Ticker)
|
||||
case status != domain.TradingStatusNormal:
|
||||
return fmt.Errorf("%s trading status %s", instrument.Ticker, status)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func New(level string, out io.Writer) *slog.Logger {
|
||||
if out == nil {
|
||||
out = os.Stdout
|
||||
}
|
||||
var slogLevel slog.Level
|
||||
switch strings.ToLower(level) {
|
||||
case "debug":
|
||||
slogLevel = slog.LevelDebug
|
||||
case "warn", "warning":
|
||||
slogLevel = slog.LevelWarn
|
||||
case "error":
|
||||
slogLevel = slog.LevelError
|
||||
default:
|
||||
slogLevel = slog.LevelInfo
|
||||
}
|
||||
return slog.New(slog.NewJSONHandler(out, &slog.HandlerOptions{Level: slogLevel}))
|
||||
}
|
||||
|
||||
type SDKLogger struct {
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
func (l SDKLogger) Infof(template string, args ...any) {
|
||||
if l.Logger != nil {
|
||||
l.Logger.Info(template, "args", args)
|
||||
}
|
||||
}
|
||||
|
||||
func (l SDKLogger) Errorf(template string, args ...any) {
|
||||
if l.Logger != nil {
|
||||
l.Logger.Error(template, "args", args)
|
||||
}
|
||||
}
|
||||
|
||||
func (l SDKLogger) Fatalf(template string, args ...any) {
|
||||
if l.Logger != nil {
|
||||
l.Logger.Error(template, "args", args)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package marketdata
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/repository"
|
||||
"overnight-trading-bot/internal/tinvest"
|
||||
)
|
||||
|
||||
type Loader struct {
|
||||
repo repository.Repository
|
||||
gateway tinvest.Gateway
|
||||
}
|
||||
|
||||
func NewLoader(repo repository.Repository, gateway tinvest.Gateway) Loader {
|
||||
return Loader{repo: repo, gateway: gateway}
|
||||
}
|
||||
|
||||
func (l Loader) BackfillDaily(ctx context.Context, instruments []domain.Instrument, from, to time.Time) error {
|
||||
for _, instrument := range instruments {
|
||||
if !instrument.Enabled || instrument.Quarantine {
|
||||
continue
|
||||
}
|
||||
candles, err := l.gateway.GetCandles(ctx, instrument.InstrumentUID, "day", from, to)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load candles %s: %w", instrument.Ticker, err)
|
||||
}
|
||||
if err := l.repo.UpsertDailyCandles(ctx, candles); err != nil {
|
||||
return fmt.Errorf("persist candles %s: %w", instrument.Ticker, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l Loader) BackfillMinute(ctx context.Context, instruments []domain.Instrument, from, to time.Time) error {
|
||||
for _, instrument := range instruments {
|
||||
if !instrument.Enabled || instrument.Quarantine {
|
||||
continue
|
||||
}
|
||||
candles, err := l.gateway.GetCandles(ctx, instrument.InstrumentUID, "minute", from, to)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load minute candles %s: %w", instrument.Ticker, err)
|
||||
}
|
||||
if err := l.repo.UpsertMinuteCandles(ctx, candles); err != nil {
|
||||
return fmt.Errorf("persist minute candles %s: %w", instrument.Ticker, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l Loader) LatestQuote(ctx context.Context, instrumentUID string, depth int32, maxAge time.Duration) (domain.OrderBook, error) {
|
||||
book, err := l.gateway.GetOrderBook(ctx, instrumentUID, depth)
|
||||
if err != nil {
|
||||
return domain.OrderBook{}, err
|
||||
}
|
||||
age := time.Since(book.ReceivedAt)
|
||||
if book.ReceivedAt.IsZero() {
|
||||
age = time.Since(book.Time)
|
||||
}
|
||||
if maxAge > 0 && age > maxAge {
|
||||
return domain.OrderBook{}, fmt.Errorf("quote age %s exceeds %s", age, maxAge)
|
||||
}
|
||||
return book, nil
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
package money
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
pb "github.com/russianinvestments/invest-api-go-sdk/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidTick = errors.New("tick must be positive")
|
||||
ErrInvalidBase = errors.New("base must be positive")
|
||||
)
|
||||
|
||||
type RoundMode int
|
||||
|
||||
const (
|
||||
RoundNearest RoundMode = iota
|
||||
RoundFloor
|
||||
RoundCeil
|
||||
)
|
||||
|
||||
func QuotationToDecimal(q *pb.Quotation) decimal.Decimal {
|
||||
if q == nil {
|
||||
return decimal.Zero
|
||||
}
|
||||
return decimal.NewFromInt(q.GetUnits()).Add(decimal.New(int64(q.GetNano()), -9))
|
||||
}
|
||||
|
||||
func DecimalToQuotation(d decimal.Decimal) *pb.Quotation {
|
||||
units := d.Truncate(0)
|
||||
nano := d.Sub(units).Mul(decimal.NewFromInt(1_000_000_000)).Round(0)
|
||||
if nano.Equal(decimal.NewFromInt(1_000_000_000)) {
|
||||
units = units.Add(decimal.NewFromInt(1))
|
||||
nano = decimal.Zero
|
||||
}
|
||||
if nano.Equal(decimal.NewFromInt(-1_000_000_000)) {
|
||||
units = units.Sub(decimal.NewFromInt(1))
|
||||
nano = decimal.Zero
|
||||
}
|
||||
nanoPart := nano.IntPart()
|
||||
if nanoPart < -999_999_999 || nanoPart > 999_999_999 {
|
||||
panic("decimal quotation nano is out of protobuf range")
|
||||
}
|
||||
return &pb.Quotation{
|
||||
Units: units.IntPart(),
|
||||
Nano: int32(nanoPart), // #nosec G115 -- nanoPart is bounded above.
|
||||
}
|
||||
}
|
||||
|
||||
func MoneyValueToDecimal(v *pb.MoneyValue) decimal.Decimal {
|
||||
if v == nil {
|
||||
return decimal.Zero
|
||||
}
|
||||
return decimal.NewFromInt(v.GetUnits()).Add(decimal.New(int64(v.GetNano()), -9))
|
||||
}
|
||||
|
||||
func Bps(part, base decimal.Decimal) (decimal.Decimal, error) {
|
||||
if !base.IsPositive() {
|
||||
return decimal.Zero, ErrInvalidBase
|
||||
}
|
||||
return part.Div(base).Mul(decimal.NewFromInt(10_000)), nil
|
||||
}
|
||||
|
||||
func FromBps(bps decimal.Decimal) decimal.Decimal {
|
||||
return bps.Div(decimal.NewFromInt(10_000))
|
||||
}
|
||||
|
||||
func RoundToTick(price, tick decimal.Decimal, mode RoundMode) (decimal.Decimal, error) {
|
||||
if !tick.IsPositive() {
|
||||
return decimal.Zero, ErrInvalidTick
|
||||
}
|
||||
steps := price.Div(tick)
|
||||
switch mode {
|
||||
case RoundFloor:
|
||||
steps = steps.Floor()
|
||||
case RoundCeil:
|
||||
steps = steps.Ceil()
|
||||
default:
|
||||
steps = steps.Round(0)
|
||||
}
|
||||
return steps.Mul(tick), nil
|
||||
}
|
||||
|
||||
func Min(values ...decimal.Decimal) decimal.Decimal {
|
||||
if len(values) == 0 {
|
||||
return decimal.Zero
|
||||
}
|
||||
min := values[0]
|
||||
for _, value := range values[1:] {
|
||||
if value.LessThan(min) {
|
||||
min = value
|
||||
}
|
||||
}
|
||||
return min
|
||||
}
|
||||
|
||||
func Max(values ...decimal.Decimal) decimal.Decimal {
|
||||
if len(values) == 0 {
|
||||
return decimal.Zero
|
||||
}
|
||||
max := values[0]
|
||||
for _, value := range values[1:] {
|
||||
if value.GreaterThan(max) {
|
||||
max = value
|
||||
}
|
||||
}
|
||||
return max
|
||||
}
|
||||
|
||||
func Abs(value decimal.Decimal) decimal.Decimal {
|
||||
if value.IsNegative() {
|
||||
return value.Neg()
|
||||
}
|
||||
return value
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package money
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
func d(raw string) decimal.Decimal {
|
||||
v, err := decimal.NewFromString(raw)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func TestRoundToTick(t *testing.T) {
|
||||
tests := []struct {
|
||||
price string
|
||||
tick string
|
||||
mode RoundMode
|
||||
want string
|
||||
}{
|
||||
{"10.12346", "0.0001", RoundNearest, "10.1235"},
|
||||
{"10.126", "0.01", RoundFloor, "10.12"},
|
||||
{"10.126", "0.01", RoundCeil, "10.13"},
|
||||
{"10.24", "0.5", RoundNearest, "10"},
|
||||
{"10.26", "0.5", RoundNearest, "10.5"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got, err := RoundToTick(d(tt.price), d(tt.tick), tt.mode)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !got.Equal(d(tt.want)) {
|
||||
t.Fatalf("RoundToTick(%s,%s)=%s want %s", tt.price, tt.tick, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,219 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
type Notifier interface {
|
||||
Info(ctx context.Context, msg string) error
|
||||
Warn(ctx context.Context, msg string) error
|
||||
Alert(ctx context.Context, msg string) error
|
||||
Report(ctx context.Context, msg string) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type Noop struct{}
|
||||
|
||||
func (Noop) Info(context.Context, string) error { return nil }
|
||||
func (Noop) Warn(context.Context, string) error { return nil }
|
||||
func (Noop) Alert(context.Context, string) error { return nil }
|
||||
func (Noop) Report(context.Context, string) error { return nil }
|
||||
func (Noop) Close() error { return nil }
|
||||
|
||||
type TelegramConfig struct {
|
||||
BotToken string
|
||||
ChatID int64
|
||||
NotifyInfo bool
|
||||
NotifyWarn bool
|
||||
NotifyAlert bool
|
||||
NotifyReport bool
|
||||
AuditSink AuditSink
|
||||
}
|
||||
|
||||
type AuditSink interface {
|
||||
InsertRiskEvent(ctx context.Context, event domain.RiskEvent) error
|
||||
}
|
||||
|
||||
type Telegram struct {
|
||||
cfg TelegramConfig
|
||||
bot *tgbotapi.BotAPI
|
||||
log *slog.Logger
|
||||
queue chan outbound
|
||||
done chan struct{}
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
type outbound struct {
|
||||
level domain.Severity
|
||||
text string
|
||||
}
|
||||
|
||||
func NewTelegram(cfg TelegramConfig, log *slog.Logger) (Notifier, error) {
|
||||
if cfg.BotToken == "" || cfg.ChatID == 0 {
|
||||
return Noop{}, nil
|
||||
}
|
||||
bot, err := tgbotapi.NewBotAPI(cfg.BotToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t := &Telegram{
|
||||
cfg: cfg,
|
||||
bot: bot,
|
||||
log: log,
|
||||
queue: make(chan outbound, 256),
|
||||
done: make(chan struct{}),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
go t.dispatch()
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *Telegram) Info(ctx context.Context, msg string) error {
|
||||
if !t.cfg.NotifyInfo {
|
||||
return nil
|
||||
}
|
||||
return t.enqueue(ctx, domain.SeverityInfo, msg, false)
|
||||
}
|
||||
|
||||
func (t *Telegram) Warn(ctx context.Context, msg string) error {
|
||||
if !t.cfg.NotifyWarn {
|
||||
return nil
|
||||
}
|
||||
return t.enqueue(ctx, domain.SeverityWarn, msg, false)
|
||||
}
|
||||
|
||||
func (t *Telegram) Alert(ctx context.Context, msg string) error {
|
||||
if !t.cfg.NotifyAlert {
|
||||
return nil
|
||||
}
|
||||
return t.enqueue(ctx, domain.SeverityAlert, msg, true)
|
||||
}
|
||||
|
||||
func (t *Telegram) Report(ctx context.Context, msg string) error {
|
||||
if !t.cfg.NotifyReport {
|
||||
return nil
|
||||
}
|
||||
return t.enqueueText(ctx, domain.SeverityInfo, formatMessage("[REPORT]", msg), true)
|
||||
}
|
||||
|
||||
func (t *Telegram) Close() error {
|
||||
close(t.done)
|
||||
<-t.closed
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Telegram) enqueue(ctx context.Context, level domain.Severity, msg string, mustDeliver bool) error {
|
||||
return t.enqueueText(ctx, level, formatMessage(prefix(level), msg), mustDeliver)
|
||||
}
|
||||
|
||||
func (t *Telegram) enqueueText(ctx context.Context, level domain.Severity, text string, mustDeliver bool) error {
|
||||
item := outbound{level: level, text: text}
|
||||
if mustDeliver {
|
||||
select {
|
||||
case t.queue <- item:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
select {
|
||||
case t.queue <- item:
|
||||
default:
|
||||
if t.log != nil {
|
||||
t.log.Warn("telegram queue full; dropping non-critical notification", "level", level)
|
||||
}
|
||||
if t.cfg.AuditSink != nil {
|
||||
_ = t.cfg.AuditSink.InsertRiskEvent(ctx, domain.RiskEvent{
|
||||
TS: time.Now().UTC(),
|
||||
Severity: domain.SeverityWarn,
|
||||
EventType: "notification_dropped",
|
||||
Message: fmt.Sprintf("telegram queue full; dropped %s notification", level),
|
||||
ContextJSON: "{}",
|
||||
})
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Telegram) dispatch() {
|
||||
defer close(t.closed)
|
||||
for {
|
||||
select {
|
||||
case item := <-t.queue:
|
||||
t.send(item)
|
||||
case <-t.done:
|
||||
for {
|
||||
select {
|
||||
case item := <-t.queue:
|
||||
t.send(item)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Telegram) send(item outbound) {
|
||||
msg := tgbotapi.NewMessage(t.cfg.ChatID, item.text)
|
||||
for attempt := 0; attempt < 3; attempt++ {
|
||||
if _, err := t.bot.Send(msg); err != nil {
|
||||
delay := telegramRetryDelay(err, attempt)
|
||||
if t.log != nil {
|
||||
t.log.Warn("telegram send failed", "attempt", attempt+1, "err", err, "retry_in", delay)
|
||||
}
|
||||
timer := time.NewTimer(delay)
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-t.done:
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func telegramRetryDelay(err error, attempt int) time.Duration {
|
||||
var apiErr tgbotapi.Error
|
||||
if errors.As(err, &apiErr) && apiErr.RetryAfter > 0 {
|
||||
return time.Duration(apiErr.RetryAfter) * time.Second
|
||||
}
|
||||
var apiErrPtr *tgbotapi.Error
|
||||
if errors.As(err, &apiErrPtr) && apiErrPtr != nil && apiErrPtr.RetryAfter > 0 {
|
||||
return time.Duration(apiErrPtr.RetryAfter) * time.Second
|
||||
}
|
||||
return time.Duration(attempt+1) * time.Second
|
||||
}
|
||||
|
||||
func prefix(level domain.Severity) string {
|
||||
switch level {
|
||||
case domain.SeverityInfo:
|
||||
return "[INFO]"
|
||||
case domain.SeverityWarn:
|
||||
return "[WARN]"
|
||||
case domain.SeverityAlert:
|
||||
return "[ALERT]"
|
||||
default:
|
||||
return fmt.Sprintf("[%s]", strings.ToUpper(string(level)))
|
||||
}
|
||||
}
|
||||
|
||||
func formatMessage(prefixValue, msg string) string {
|
||||
return prefixValue + " " + msg
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
func TestPrefix(t *testing.T) {
|
||||
tests := map[domain.Severity]string{
|
||||
domain.SeverityInfo: "[INFO]",
|
||||
domain.SeverityWarn: "[WARN]",
|
||||
domain.SeverityAlert: "[ALERT]",
|
||||
domain.Severity("alert"): "[ALERT]",
|
||||
}
|
||||
for severity, want := range tests {
|
||||
if got := prefix(severity); got != want {
|
||||
t.Fatalf("prefix(%s)=%s, want %s", severity, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatReportPrefix(t *testing.T) {
|
||||
if got := formatMessage("[REPORT]", "daily"); got != "[REPORT] daily" {
|
||||
t.Fatalf("message=%s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTelegramRetryDelayUsesRetryAfter(t *testing.T) {
|
||||
err := &tgbotapi.Error{
|
||||
Code: 429,
|
||||
ResponseParameters: tgbotapi.ResponseParameters{
|
||||
RetryAfter: 7,
|
||||
},
|
||||
}
|
||||
if got := telegramRetryDelay(err, 0); got != 7*time.Second {
|
||||
t.Fatalf("delay=%s, want 7s", got)
|
||||
}
|
||||
if got := telegramRetryDelay(assertErr{}, 1); got != 2*time.Second {
|
||||
t.Fatalf("fallback delay=%s, want 2s", got)
|
||||
}
|
||||
}
|
||||
|
||||
type assertErr struct{}
|
||||
|
||||
func (assertErr) Error() string { return "boom" }
|
||||
@@ -0,0 +1,93 @@
|
||||
package position
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/money"
|
||||
"overnight-trading-bot/internal/repository"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
repo repository.Repository
|
||||
}
|
||||
|
||||
func NewManager(repo repository.Repository) Manager {
|
||||
return Manager{repo: repo}
|
||||
}
|
||||
|
||||
func (m Manager) OnEntryFill(ctx context.Context, accountIDHash string, instrument domain.Instrument, order domain.Order) (domain.Position, error) {
|
||||
now := time.Now().UTC()
|
||||
lot := instrument.Lot
|
||||
if lot <= 0 {
|
||||
lot = 1
|
||||
}
|
||||
pos := domain.Position{
|
||||
AccountIDHash: accountIDHash,
|
||||
InstrumentUID: order.InstrumentUID,
|
||||
OpenTradeDate: order.TradeDate,
|
||||
Lots: order.FilledLots,
|
||||
Lot: lot,
|
||||
AvgBuyPrice: order.AvgFillPrice,
|
||||
CommissionTotal: order.Commission,
|
||||
Status: domain.PositionHoldingOvernight,
|
||||
OpenedAt: &now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if pos.Lots < order.QuantityLots {
|
||||
pos.Status = domain.PositionEntryPartiallyFilled
|
||||
}
|
||||
if err := m.repo.UpsertPosition(ctx, pos); err != nil {
|
||||
return domain.Position{}, err
|
||||
}
|
||||
return pos, nil
|
||||
}
|
||||
|
||||
func (m Manager) OnExitFill(ctx context.Context, pos domain.Position, exitOrder domain.Order) (domain.Position, error) {
|
||||
now := time.Now().UTC()
|
||||
lot := pos.Lot
|
||||
if lot <= 0 {
|
||||
lot = 1
|
||||
}
|
||||
executedLots := min(exitOrder.FilledLots, pos.Lots)
|
||||
if executedLots < 0 {
|
||||
executedLots = 0
|
||||
}
|
||||
previousExitLots := pos.ExitFilledLots
|
||||
pos.ExitFilledLots += executedLots
|
||||
if executedLots > 0 {
|
||||
previousValue := pos.AvgSellPrice.Mul(decimal.NewFromInt(previousExitLots))
|
||||
newValue := exitOrder.AvgFillPrice.Mul(decimal.NewFromInt(executedLots))
|
||||
pos.AvgSellPrice = previousValue.Add(newValue).Div(decimal.NewFromInt(pos.ExitFilledLots))
|
||||
}
|
||||
pos.CommissionTotal = pos.CommissionTotal.Add(exitOrder.Commission)
|
||||
executedUnits := decimal.NewFromInt(executedLots).Mul(decimal.NewFromInt(lot))
|
||||
pos.GrossPnL = pos.GrossPnL.Add(exitOrder.AvgFillPrice.Sub(pos.AvgBuyPrice).Mul(executedUnits))
|
||||
pos.NetPnL = pos.GrossPnL.Sub(pos.CommissionTotal)
|
||||
if pos.AvgBuyPrice.IsPositive() {
|
||||
baseLots := pos.ExitFilledLots
|
||||
if baseLots <= 0 {
|
||||
baseLots = pos.Lots
|
||||
}
|
||||
base := pos.AvgBuyPrice.Mul(decimal.NewFromInt(baseLots)).Mul(decimal.NewFromInt(lot))
|
||||
edge, _ := money.Bps(pos.NetPnL, base)
|
||||
pos.RealizedEdgeBps = edge
|
||||
}
|
||||
pos.Status = domain.PositionExitFilled
|
||||
if executedLots < pos.Lots {
|
||||
pos.Lots -= executedLots
|
||||
pos.Status = domain.PositionExitPartiallyFilled
|
||||
pos.ClosedAt = nil
|
||||
} else {
|
||||
pos.Lots = 0
|
||||
pos.ClosedAt = &now
|
||||
}
|
||||
pos.UpdatedAt = now
|
||||
if err := m.repo.UpsertPosition(ctx, pos); err != nil {
|
||||
return domain.Position{}, err
|
||||
}
|
||||
return pos, nil
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package position
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/testutil"
|
||||
)
|
||||
|
||||
func TestOnEntryFillKeepsBuyCommission(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
manager := NewManager(testutil.NewMemoryRepository())
|
||||
pos, err := manager.OnEntryFill(ctx, "hash", domain.Instrument{Lot: 1}, domain.Order{
|
||||
InstrumentUID: "uid",
|
||||
TradeDate: time.Now().UTC(),
|
||||
QuantityLots: 10,
|
||||
FilledLots: 10,
|
||||
AvgFillPrice: decimal.NewFromInt(100),
|
||||
Commission: decimal.NewFromInt(3),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !pos.CommissionTotal.Equal(decimal.NewFromInt(3)) {
|
||||
t.Fatalf("commission=%s, want 3", pos.CommissionTotal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnExitFillPartialUsesExecutedLots(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
manager := NewManager(testutil.NewMemoryRepository())
|
||||
openAt := time.Now().UTC()
|
||||
pos := domain.Position{
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid",
|
||||
OpenTradeDate: openAt,
|
||||
Lots: 10,
|
||||
Lot: 1,
|
||||
AvgBuyPrice: decimal.NewFromInt(100),
|
||||
Status: domain.PositionHoldingOvernight,
|
||||
CommissionTotal: decimal.NewFromInt(2),
|
||||
OpenedAt: &openAt,
|
||||
}
|
||||
updated, err := manager.OnExitFill(ctx, pos, domain.Order{
|
||||
InstrumentUID: "uid",
|
||||
FilledLots: 4,
|
||||
AvgFillPrice: decimal.NewFromInt(110),
|
||||
Commission: decimal.NewFromInt(1),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if updated.Status != domain.PositionExitPartiallyFilled || updated.ClosedAt != nil {
|
||||
t.Fatalf("unexpected partial status/closed_at: %+v", updated)
|
||||
}
|
||||
if updated.Lots != 6 {
|
||||
t.Fatalf("remaining lots=%d, want 6", updated.Lots)
|
||||
}
|
||||
if !updated.GrossPnL.Equal(decimal.NewFromInt(40)) {
|
||||
t.Fatalf("gross pnl=%s, want 40", updated.GrossPnL)
|
||||
}
|
||||
if updated.ExitFilledLots != 4 || !updated.AvgSellPrice.Equal(decimal.NewFromInt(110)) {
|
||||
t.Fatalf("exit aggregation lots=%d avg=%s", updated.ExitFilledLots, updated.AvgSellPrice)
|
||||
}
|
||||
second, err := manager.OnExitFill(ctx, updated, domain.Order{
|
||||
InstrumentUID: "uid",
|
||||
FilledLots: 3,
|
||||
AvgFillPrice: decimal.NewFromInt(120),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wantAvg := decimal.NewFromInt(800).Div(decimal.NewFromInt(7))
|
||||
if second.ExitFilledLots != 7 || !second.AvgSellPrice.Equal(wantAvg) {
|
||||
t.Fatalf("weighted avg sell=%s lots=%d, want %s/7", second.AvgSellPrice, second.ExitFilledLots, wantAvg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnExitFillUsesInstrumentLotForAbsolutePnL(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
manager := NewManager(testutil.NewMemoryRepository())
|
||||
openAt := time.Now().UTC()
|
||||
pos := domain.Position{
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid",
|
||||
OpenTradeDate: openAt,
|
||||
Lots: 4,
|
||||
Lot: 10,
|
||||
AvgBuyPrice: decimal.NewFromInt(100),
|
||||
Status: domain.PositionHoldingOvernight,
|
||||
CommissionTotal: decimal.NewFromInt(2),
|
||||
OpenedAt: &openAt,
|
||||
}
|
||||
updated, err := manager.OnExitFill(ctx, pos, domain.Order{
|
||||
InstrumentUID: "uid",
|
||||
FilledLots: 4,
|
||||
AvgFillPrice: decimal.NewFromInt(105),
|
||||
Commission: decimal.NewFromInt(3),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !updated.GrossPnL.Equal(decimal.NewFromInt(200)) {
|
||||
t.Fatalf("gross pnl=%s, want 200", updated.GrossPnL)
|
||||
}
|
||||
if !updated.NetPnL.Equal(decimal.NewFromInt(195)) {
|
||||
t.Fatalf("net pnl=%s, want 195", updated.NetPnL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnExitFillUsesLotInRealizedEdgeCommissionBase(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
manager := NewManager(testutil.NewMemoryRepository())
|
||||
openAt := time.Now().UTC()
|
||||
pos := domain.Position{
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid",
|
||||
OpenTradeDate: openAt,
|
||||
Lots: 1,
|
||||
Lot: 100,
|
||||
AvgBuyPrice: decimal.NewFromInt(100),
|
||||
Status: domain.PositionHoldingOvernight,
|
||||
OpenedAt: &openAt,
|
||||
}
|
||||
updated, err := manager.OnExitFill(ctx, pos, domain.Order{
|
||||
InstrumentUID: "uid",
|
||||
FilledLots: 1,
|
||||
AvgFillPrice: decimal.NewFromInt(100),
|
||||
Commission: decimal.NewFromInt(10),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !updated.RealizedEdgeBps.Equal(decimal.NewFromInt(-10)) {
|
||||
t.Fatalf("realized edge=%s, want -10 bps", updated.RealizedEdgeBps)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,230 @@
|
||||
package reconciliation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/money"
|
||||
"overnight-trading-bot/internal/repository"
|
||||
"overnight-trading-bot/internal/tinvest"
|
||||
)
|
||||
|
||||
type Engine struct {
|
||||
repo repository.Repository
|
||||
gateway tinvest.Gateway
|
||||
accountID string
|
||||
accountIDHash string
|
||||
window time.Duration
|
||||
inFlightGrace time.Duration
|
||||
}
|
||||
|
||||
func New(repo repository.Repository, gateway tinvest.Gateway, accountID, accountIDHash string) Engine {
|
||||
return Engine{repo: repo, gateway: gateway, accountID: accountID, accountIDHash: accountIDHash, window: 72 * time.Hour}
|
||||
}
|
||||
|
||||
func (e Engine) WithWindow(window time.Duration) Engine {
|
||||
if window > 0 {
|
||||
e.window = window
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
func (e Engine) WithInFlightGrace(grace time.Duration) Engine {
|
||||
if grace >= 0 {
|
||||
e.inFlightGrace = grace
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) {
|
||||
localOrders, err := e.repo.ListActiveOrders(ctx, e.accountIDHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
brokerOrders, err := e.gateway.GetActiveOrders(ctx, e.accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
localByBroker := make(map[string]domain.Order, len(localOrders))
|
||||
brokerByID := make(map[string]domain.Order, len(brokerOrders))
|
||||
for _, order := range localOrders {
|
||||
if order.BrokerOrderID != "" {
|
||||
localByBroker[order.BrokerOrderID] = order
|
||||
}
|
||||
}
|
||||
var diffs []domain.ReconciliationDiff
|
||||
for _, brokerOrder := range brokerOrders {
|
||||
brokerByID[brokerOrder.BrokerOrderID] = brokerOrder
|
||||
if _, ok := localByBroker[brokerOrder.BrokerOrderID]; !ok {
|
||||
diffs = append(diffs, domain.ReconciliationDiff{
|
||||
Kind: "unknown_active_order",
|
||||
InstrumentUID: brokerOrder.InstrumentUID,
|
||||
Message: fmt.Sprintf("broker order %s is not known locally", brokerOrder.BrokerOrderID),
|
||||
Critical: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, localOrder := range localOrders {
|
||||
if e.isInFlight(localOrder, now) {
|
||||
continue
|
||||
}
|
||||
if localOrder.BrokerOrderID == "" {
|
||||
diffs = append(diffs, domain.ReconciliationDiff{
|
||||
Kind: "local_order_without_broker_id",
|
||||
InstrumentUID: localOrder.InstrumentUID,
|
||||
Message: fmt.Sprintf("local order %s is active without broker order id", localOrder.ClientOrderID),
|
||||
Critical: true,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if _, ok := brokerByID[localOrder.BrokerOrderID]; !ok {
|
||||
diffs = append(diffs, domain.ReconciliationDiff{
|
||||
Kind: "missing_local_order",
|
||||
InstrumentUID: localOrder.InstrumentUID,
|
||||
Message: fmt.Sprintf("local active order %s/%s is not active at broker", localOrder.ClientOrderID, localOrder.BrokerOrderID),
|
||||
Critical: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
localPositions, err := e.repo.ListOpenPositions(ctx, e.accountIDHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
portfolio, err := e.gateway.GetPortfolio(ctx, e.accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
brokerLots := make(map[string]int64, len(portfolio.Holdings))
|
||||
for _, holding := range portfolio.Holdings {
|
||||
brokerLots[holding.InstrumentUID] += holding.QuantityLots
|
||||
}
|
||||
for _, pos := range localPositions {
|
||||
if brokerLots[pos.InstrumentUID] != pos.Lots {
|
||||
diffs = append(diffs, domain.ReconciliationDiff{
|
||||
Kind: "position_lots_mismatch",
|
||||
InstrumentUID: pos.InstrumentUID,
|
||||
Message: fmt.Sprintf("local lots=%d broker lots=%d", pos.Lots, brokerLots[pos.InstrumentUID]),
|
||||
Critical: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
localLots := make(map[string]int64, len(localPositions))
|
||||
for _, pos := range localPositions {
|
||||
localLots[pos.InstrumentUID] += pos.Lots
|
||||
}
|
||||
for instrumentUID, lots := range brokerLots {
|
||||
if lots > 0 && localLots[instrumentUID] == 0 {
|
||||
diffs = append(diffs, domain.ReconciliationDiff{
|
||||
Kind: "unknown_broker_position",
|
||||
InstrumentUID: instrumentUID,
|
||||
Message: fmt.Sprintf("broker holds %d lots but local position is absent", lots),
|
||||
Critical: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
from := now.Add(-e.window)
|
||||
recentOrders, err := e.repo.ListOrders(ctx, e.accountIDHash, from, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
operations, err := e.gateway.GetOperations(ctx, e.accountID, from, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
diffs = append(diffs, compareOperations(recentOrders, operations)...)
|
||||
raw, _ := json.Marshal(diffs)
|
||||
if err := e.repo.InsertReconciliation(ctx, now, string(raw), len(diffs) > 0); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return diffs, nil
|
||||
}
|
||||
|
||||
func (e Engine) isInFlight(order domain.Order, now time.Time) bool {
|
||||
if e.inFlightGrace <= 0 || order.CreatedAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
return order.CreatedAt.After(now.Add(-e.inFlightGrace))
|
||||
}
|
||||
|
||||
func HasCritical(diffs []domain.ReconciliationDiff) bool {
|
||||
for _, diff := range diffs {
|
||||
if diff.Critical {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func compareOperations(orders []domain.Order, operations []domain.Operation) []domain.ReconciliationDiff {
|
||||
var diffs []domain.ReconciliationDiff
|
||||
localCommissionByInstrument := make(map[string]decimal.Decimal)
|
||||
localTraded := make(map[string]bool)
|
||||
for _, order := range orders {
|
||||
if order.Status == domain.OrderStatusFilled || order.Status == domain.OrderStatusPartiallyFilled {
|
||||
localCommissionByInstrument[order.InstrumentUID] = localCommissionByInstrument[order.InstrumentUID].Add(order.Commission)
|
||||
localTraded[order.InstrumentUID] = true
|
||||
}
|
||||
}
|
||||
brokerCommissionByInstrument := make(map[string]decimal.Decimal)
|
||||
brokerTraded := make(map[string]bool)
|
||||
for _, op := range operations {
|
||||
if !op.Commission.IsZero() {
|
||||
brokerCommissionByInstrument[op.InstrumentUID] = brokerCommissionByInstrument[op.InstrumentUID].Add(op.Commission)
|
||||
}
|
||||
if isTradeOperation(op.Type) {
|
||||
brokerTraded[op.InstrumentUID] = true
|
||||
}
|
||||
}
|
||||
instruments := make(map[string]struct{}, len(localCommissionByInstrument)+len(brokerCommissionByInstrument))
|
||||
for instrumentUID := range localCommissionByInstrument {
|
||||
instruments[instrumentUID] = struct{}{}
|
||||
}
|
||||
for instrumentUID := range brokerCommissionByInstrument {
|
||||
instruments[instrumentUID] = struct{}{}
|
||||
}
|
||||
for instrumentUID := range instruments {
|
||||
localCommission := localCommissionByInstrument[instrumentUID]
|
||||
brokerCommission := brokerCommissionByInstrument[instrumentUID]
|
||||
if diff := money.Abs(localCommission.Sub(brokerCommission)); diff.GreaterThan(decimal.NewFromFloat(0.01)) {
|
||||
diffs = append(diffs, domain.ReconciliationDiff{
|
||||
Kind: "commission_mismatch",
|
||||
InstrumentUID: instrumentUID,
|
||||
Message: fmt.Sprintf("local commission=%s broker commission=%s", localCommission.StringFixed(2), brokerCommission.StringFixed(2)),
|
||||
Critical: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
for instrumentUID := range brokerTraded {
|
||||
if instrumentUID != "" && !localTraded[instrumentUID] {
|
||||
diffs = append(diffs, domain.ReconciliationDiff{
|
||||
Kind: "unknown_broker_operation",
|
||||
InstrumentUID: instrumentUID,
|
||||
Message: "broker has executed operation without local filled order",
|
||||
Critical: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
for instrumentUID := range localTraded {
|
||||
if !brokerTraded[instrumentUID] {
|
||||
diffs = append(diffs, domain.ReconciliationDiff{
|
||||
Kind: "missing_broker_operation",
|
||||
InstrumentUID: instrumentUID,
|
||||
Message: "local filled order has no matching broker operation in reconciliation window",
|
||||
Critical: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
return diffs
|
||||
}
|
||||
|
||||
func isTradeOperation(raw string) bool {
|
||||
raw = strings.ToUpper(raw)
|
||||
return strings.Contains(raw, "OPERATION_TYPE_BUY") || strings.Contains(raw, "OPERATION_TYPE_SELL")
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package reconciliation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/testutil"
|
||||
"overnight-trading-bot/internal/tinvest"
|
||||
)
|
||||
|
||||
func TestReconciliationFindsCriticalDiffs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
gateway := tinvest.NewFakeGateway()
|
||||
now := time.Now().UTC()
|
||||
if err := repo.UpsertOrder(ctx, domain.Order{
|
||||
ClientOrderID: "local",
|
||||
BrokerOrderID: "broker-missing",
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid-local",
|
||||
TradeDate: now,
|
||||
Side: domain.SideBuy,
|
||||
OrderType: domain.OrderTypeLimit,
|
||||
QuantityLots: 1,
|
||||
Status: domain.OrderStatusSent,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
gateway.Orders["broker-unknown"] = domain.Order{
|
||||
ClientOrderID: "unknown",
|
||||
BrokerOrderID: "broker-unknown",
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid-broker",
|
||||
QuantityLots: 1,
|
||||
Status: domain.OrderStatusSent,
|
||||
}
|
||||
if err := repo.UpsertPosition(ctx, domain.Position{
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid-local",
|
||||
OpenTradeDate: now,
|
||||
Lots: 2,
|
||||
Status: domain.PositionHoldingOvernight,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
gateway.Portfolio = domain.Portfolio{
|
||||
Equity: decimal.NewFromInt(100000),
|
||||
Cash: decimal.NewFromInt(90000),
|
||||
Holdings: []domain.Holding{
|
||||
{InstrumentUID: "uid-local", QuantityLots: 1},
|
||||
{InstrumentUID: "uid-broker-only", QuantityLots: 3},
|
||||
},
|
||||
}
|
||||
diffs, err := New(repo, gateway, "account", "hash").Run(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wantKinds := map[string]bool{
|
||||
"unknown_active_order": false,
|
||||
"missing_local_order": false,
|
||||
"position_lots_mismatch": false,
|
||||
"unknown_broker_position": false,
|
||||
}
|
||||
for _, diff := range diffs {
|
||||
if _, ok := wantKinds[diff.Kind]; ok {
|
||||
wantKinds[diff.Kind] = true
|
||||
}
|
||||
}
|
||||
for kind, seen := range wantKinds {
|
||||
if !seen {
|
||||
t.Fatalf("missing diff kind %s in %+v", kind, diffs)
|
||||
}
|
||||
}
|
||||
if !HasCritical(diffs) {
|
||||
t.Fatalf("expected critical diffs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareOperationsCommissionPerInstrument(t *testing.T) {
|
||||
orders := []domain.Order{
|
||||
{InstrumentUID: "TRUR", Status: domain.OrderStatusFilled, Commission: decimal.NewFromInt(2)},
|
||||
{InstrumentUID: "TGLD", Status: domain.OrderStatusFilled, Commission: decimal.NewFromInt(1)},
|
||||
}
|
||||
operations := []domain.Operation{
|
||||
{InstrumentUID: "TRUR", Type: "OPERATION_TYPE_BUY", Commission: decimal.NewFromInt(1)},
|
||||
{InstrumentUID: "TGLD", Type: "OPERATION_TYPE_BUY", Commission: decimal.NewFromInt(2)},
|
||||
}
|
||||
diffs := compareOperations(orders, operations)
|
||||
seen := map[string]bool{}
|
||||
for _, diff := range diffs {
|
||||
if diff.Kind == "commission_mismatch" {
|
||||
seen[diff.InstrumentUID] = true
|
||||
}
|
||||
}
|
||||
if !seen["TRUR"] || !seen["TGLD"] {
|
||||
t.Fatalf("expected per-instrument commission diffs, got %+v", diffs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconciliationSkipsFreshInFlightLocalOrders(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
gateway := tinvest.NewFakeGateway()
|
||||
now := time.Now().UTC()
|
||||
if err := repo.UpsertOrder(ctx, domain.Order{
|
||||
ClientOrderID: "fresh",
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid",
|
||||
TradeDate: now,
|
||||
Side: domain.SideBuy,
|
||||
OrderType: domain.OrderTypeLimit,
|
||||
QuantityLots: 1,
|
||||
Status: domain.OrderStatusSent,
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
diffs, err := New(repo, gateway, "account", "hash").WithInFlightGrace(10 * time.Second).Run(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, diff := range diffs {
|
||||
if diff.Kind == "local_order_without_broker_id" || diff.Kind == "missing_local_order" {
|
||||
t.Fatalf("fresh in-flight order produced diff: %+v", diffs)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package report
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
type DailyInput struct {
|
||||
Date time.Time
|
||||
Mode domain.Mode
|
||||
Signals []domain.Signal
|
||||
Positions []domain.Position
|
||||
AverageSpreadBps decimal.Decimal
|
||||
AverageSlipBps decimal.Decimal
|
||||
RiskStatus string
|
||||
}
|
||||
|
||||
func ComposeDaily(input DailyInput) string {
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "Дата: %s\n", input.Date.Format("2006-01-02"))
|
||||
fmt.Fprintf(&b, "Режим: %s\n", input.Mode)
|
||||
fmt.Fprintf(&b, "Сигналы: %d\n", len(input.Signals))
|
||||
for _, signal := range input.Signals {
|
||||
fmt.Fprintf(&b, "- %s %s edge=%s reason=%s\n", signal.InstrumentUID, signal.Decision, signal.NetEdgeBps.StringFixed(2), signal.RejectReason)
|
||||
}
|
||||
gross := decimal.Zero
|
||||
net := decimal.Zero
|
||||
commission := decimal.Zero
|
||||
for _, pos := range input.Positions {
|
||||
gross = gross.Add(pos.GrossPnL)
|
||||
net = net.Add(pos.NetPnL)
|
||||
commission = commission.Add(pos.CommissionTotal)
|
||||
}
|
||||
fmt.Fprintf(&b, "Gross PnL: %s\n", gross.StringFixed(2))
|
||||
fmt.Fprintf(&b, "Net PnL: %s\n", net.StringFixed(2))
|
||||
fmt.Fprintf(&b, "Комиссии: %s\n", commission.StringFixed(2))
|
||||
fmt.Fprintf(&b, "Средний spread: %s bps\n", input.AverageSpreadBps.StringFixed(2))
|
||||
fmt.Fprintf(&b, "Среднее проскальзывание: %s bps\n", input.AverageSlipBps.StringFixed(2))
|
||||
fmt.Fprintf(&b, "Risk: %s", input.RiskStatus)
|
||||
return b.String()
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
DROP TABLE IF EXISTS reconciliations;
|
||||
DROP TABLE IF EXISTS daily_reports;
|
||||
DROP TABLE IF EXISTS system_state;
|
||||
DROP TABLE IF EXISTS free_order_counters;
|
||||
DROP TABLE IF EXISTS risk_events;
|
||||
DROP TABLE IF EXISTS positions;
|
||||
DROP TABLE IF EXISTS orders;
|
||||
DROP TABLE IF EXISTS signals;
|
||||
DROP TABLE IF EXISTS features;
|
||||
DROP TABLE IF EXISTS candles_minute;
|
||||
DROP TABLE IF EXISTS candles_daily;
|
||||
DROP TABLE IF EXISTS instruments;
|
||||
DROP TABLE IF EXISTS schema_meta;
|
||||
@@ -0,0 +1,181 @@
|
||||
CREATE TABLE IF NOT EXISTS schema_meta (
|
||||
meta_key VARCHAR(64) PRIMARY KEY,
|
||||
meta_value VARCHAR(255) NOT NULL,
|
||||
updated_at DATETIME(3) NOT NULL DEFAULT CURRENT_TIMESTAMP(3) ON UPDATE CURRENT_TIMESTAMP(3)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS instruments (
|
||||
instrument_uid VARCHAR(128) PRIMARY KEY,
|
||||
figi VARCHAR(64),
|
||||
ticker VARCHAR(32) NOT NULL,
|
||||
class_code VARCHAR(32) NOT NULL DEFAULT 'TQTF',
|
||||
name VARCHAR(255) NOT NULL DEFAULT '',
|
||||
lot BIGINT NOT NULL DEFAULT 1,
|
||||
min_price_increment DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
currency VARCHAR(8) NOT NULL DEFAULT 'RUB',
|
||||
enabled TINYINT(1) NOT NULL DEFAULT 1,
|
||||
fund_type VARCHAR(64) NOT NULL DEFAULT '',
|
||||
expected_commission_bps_per_side DECIMAL(12,4) NOT NULL DEFAULT 0,
|
||||
free_order_limit_per_day INT NOT NULL DEFAULT 0 COMMENT '0 means no configured free-order cap',
|
||||
quarantine TINYINT(1) NOT NULL DEFAULT 0,
|
||||
quarantine_reason TEXT,
|
||||
exclude_reason TEXT,
|
||||
updated_at DATETIME(3) NOT NULL DEFAULT CURRENT_TIMESTAMP(3) ON UPDATE CURRENT_TIMESTAMP(3),
|
||||
UNIQUE KEY ux_instruments_ticker_class (ticker, class_code)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS candles_daily (
|
||||
instrument_uid VARCHAR(128) NOT NULL,
|
||||
trade_date DATE NOT NULL,
|
||||
open DECIMAL(20,8) NOT NULL,
|
||||
high DECIMAL(20,8) NOT NULL,
|
||||
low DECIMAL(20,8) NOT NULL,
|
||||
close DECIMAL(20,8) NOT NULL,
|
||||
volume_lots DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
source VARCHAR(32) NOT NULL,
|
||||
loaded_at DATETIME(3) NOT NULL,
|
||||
PRIMARY KEY (instrument_uid, trade_date),
|
||||
CONSTRAINT fk_candles_daily_instrument FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS candles_minute (
|
||||
instrument_uid VARCHAR(128) NOT NULL,
|
||||
ts DATETIME(3) NOT NULL,
|
||||
open DECIMAL(20,8) NOT NULL,
|
||||
high DECIMAL(20,8) NOT NULL,
|
||||
low DECIMAL(20,8) NOT NULL,
|
||||
close DECIMAL(20,8) NOT NULL,
|
||||
volume_lots DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
source VARCHAR(32) NOT NULL,
|
||||
loaded_at DATETIME(3) NOT NULL,
|
||||
PRIMARY KEY (instrument_uid, ts),
|
||||
CONSTRAINT fk_candles_minute_instrument FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS features (
|
||||
instrument_uid VARCHAR(128) NOT NULL,
|
||||
trade_date DATE NOT NULL,
|
||||
r_on DECIMAL(20,10) NOT NULL DEFAULT 0,
|
||||
r_day DECIMAL(20,10) NOT NULL DEFAULT 0,
|
||||
mu_on_60 DECIMAL(20,10) NOT NULL DEFAULT 0,
|
||||
mu_on_252 DECIMAL(20,10) NOT NULL DEFAULT 0,
|
||||
sigma_on_60 DECIMAL(20,10) NOT NULL DEFAULT 0,
|
||||
tstat_on_60 DECIMAL(20,10) NOT NULL DEFAULT 0,
|
||||
win_on_60 DECIMAL(20,10) NOT NULL DEFAULT 0,
|
||||
ewma_on DECIMAL(20,10) NOT NULL DEFAULT 0,
|
||||
spread_bps DECIMAL(12,4) NOT NULL DEFAULT 0,
|
||||
half_spread_bps DECIMAL(12,4) NOT NULL DEFAULT 0,
|
||||
tick_bps DECIMAL(12,4) NOT NULL DEFAULT 0,
|
||||
adv_20 DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
expected_cost_bps DECIMAL(12,4) NOT NULL DEFAULT 0,
|
||||
net_edge_bps DECIMAL(12,4) NOT NULL DEFAULT 0,
|
||||
entry_interval_volume DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
exit_interval_volume DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
calculated_at DATETIME(3) NOT NULL,
|
||||
PRIMARY KEY (instrument_uid, trade_date),
|
||||
CONSTRAINT fk_features_instrument FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS signals (
|
||||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
trade_date DATE NOT NULL,
|
||||
instrument_uid VARCHAR(128) NOT NULL,
|
||||
decision ENUM('ENTER','SKIP','REJECT') NOT NULL,
|
||||
score DECIMAL(20,10) NOT NULL DEFAULT 0,
|
||||
net_edge_bps DECIMAL(12,4) NOT NULL DEFAULT 0,
|
||||
target_notional DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
target_lots BIGINT NOT NULL DEFAULT 0,
|
||||
reject_reason VARCHAR(128),
|
||||
context_json JSON,
|
||||
created_at DATETIME(3) NOT NULL,
|
||||
UNIQUE KEY ux_signals_date_instr (trade_date, instrument_uid),
|
||||
CONSTRAINT fk_signals_instrument FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS orders (
|
||||
client_order_id VARCHAR(128) PRIMARY KEY,
|
||||
broker_order_id VARCHAR(128),
|
||||
account_id_hash VARCHAR(128) NOT NULL,
|
||||
instrument_uid VARCHAR(128) NOT NULL,
|
||||
trade_date DATE NOT NULL,
|
||||
side ENUM('BUY','SELL') NOT NULL,
|
||||
order_type ENUM('LIMIT') NOT NULL,
|
||||
limit_price DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
quantity_lots BIGINT NOT NULL,
|
||||
filled_lots BIGINT NOT NULL DEFAULT 0,
|
||||
avg_fill_price DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
status ENUM('NEW','SENT','PARTIALLY_FILLED','FILLED','CANCELLED','REJECTED','EXPIRED','FAILED') NOT NULL,
|
||||
commission DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
attempt_no INT NOT NULL DEFAULT 1,
|
||||
raw_state_json JSON,
|
||||
created_at DATETIME(3) NOT NULL,
|
||||
updated_at DATETIME(3) NOT NULL,
|
||||
UNIQUE KEY ux_orders_broker_order_id (broker_order_id),
|
||||
KEY ix_orders_active (account_id_hash, status),
|
||||
CONSTRAINT fk_orders_instrument FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS positions (
|
||||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
account_id_hash VARCHAR(128) NOT NULL,
|
||||
instrument_uid VARCHAR(128) NOT NULL,
|
||||
open_trade_date DATE NOT NULL,
|
||||
lots BIGINT NOT NULL,
|
||||
avg_buy_price DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
avg_sell_price DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
status ENUM('NO_POSITION','ENTRY_SIGNALLED','ENTRY_ORDER_SENT','ENTRY_PARTIALLY_FILLED','ENTRY_FILLED','HOLDING_OVERNIGHT','EXIT_ORDER_SENT','EXIT_PARTIALLY_FILLED','EXIT_FILLED','EXIT_FAILED','QUARANTINE') NOT NULL,
|
||||
gross_pnl DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
net_pnl DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
commission_total DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
realized_edge_bps DECIMAL(12,4) NOT NULL DEFAULT 0,
|
||||
opened_at DATETIME(3),
|
||||
closed_at DATETIME(3),
|
||||
updated_at DATETIME(3) NOT NULL,
|
||||
KEY ix_positions_open (account_id_hash, status),
|
||||
CONSTRAINT fk_positions_instrument FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS risk_events (
|
||||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
ts DATETIME(3) NOT NULL,
|
||||
severity ENUM('INFO','WARN','ALERT','CRITICAL') NOT NULL,
|
||||
event_type VARCHAR(128) NOT NULL,
|
||||
instrument_uid VARCHAR(128),
|
||||
message TEXT NOT NULL,
|
||||
raw_context_json JSON,
|
||||
KEY ix_risk_events_ts (ts)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS free_order_counters (
|
||||
trade_date DATE NOT NULL,
|
||||
instrument_uid VARCHAR(128) NOT NULL,
|
||||
orders_sent INT NOT NULL DEFAULT 0,
|
||||
updated_at DATETIME(3) NOT NULL DEFAULT CURRENT_TIMESTAMP(3) ON UPDATE CURRENT_TIMESTAMP(3),
|
||||
PRIMARY KEY (trade_date, instrument_uid),
|
||||
CONSTRAINT fk_free_orders_instrument FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS system_state (
|
||||
id TINYINT NOT NULL PRIMARY KEY,
|
||||
state ENUM('INIT','SYNC_INSTRUMENTS','SYNC_MARKET_DATA','GENERATE_SIGNALS','WAIT_ENTRY_WINDOW','PLACE_ENTRY_ORDERS','MONITOR_ENTRY_ORDERS','HOLD_OVERNIGHT','WAIT_EXIT_WINDOW','PLACE_EXIT_ORDERS','MONITOR_EXIT_ORDERS','RECONCILE','REPORT','SLEEP','HALTED') NOT NULL,
|
||||
mode ENUM('backtest','paper','sandbox','live_readonly','live_trade') NOT NULL,
|
||||
halted TINYINT(1) NOT NULL DEFAULT 0,
|
||||
halt_reason TEXT,
|
||||
last_heartbeat DATETIME(3) NOT NULL,
|
||||
context_json JSON
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS reconciliations (
|
||||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
ts DATETIME(3) NOT NULL,
|
||||
has_diff TINYINT(1) NOT NULL,
|
||||
diff_json JSON,
|
||||
KEY ix_reconciliations_ts (ts)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
INSERT INTO schema_meta(meta_key, meta_value) VALUES ('schema_version', '0001')
|
||||
ON DUPLICATE KEY UPDATE meta_value=VALUES(meta_value);
|
||||
|
||||
INSERT INTO system_state(id, state, mode, halted, last_heartbeat, context_json)
|
||||
VALUES (1, 'INIT', 'paper', 0, UTC_TIMESTAMP(3), JSON_OBJECT())
|
||||
ON DUPLICATE KEY UPDATE id=id;
|
||||
@@ -0,0 +1,7 @@
|
||||
DELETE FROM instruments WHERE instrument_uid IN (
|
||||
'PENDING:TRUR','PENDING:TGLD','PENDING:TBRU','PENDING:TDIV','PENDING:TMON',
|
||||
'PENDING:TOFZ','PENDING:TLCB','PENDING:TITR','PENDING:TRND','PENDING:TMOS'
|
||||
);
|
||||
|
||||
UPDATE schema_meta SET meta_value='0001' WHERE meta_key='schema_version';
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
INSERT INTO instruments (
|
||||
instrument_uid, ticker, class_code, name, lot, min_price_increment, currency,
|
||||
enabled, fund_type, expected_commission_bps_per_side, free_order_limit_per_day,
|
||||
quarantine, exclude_reason, updated_at
|
||||
) VALUES
|
||||
('PENDING:TRUR', 'TRUR', 'TQTF', 'TRUR', 1, 0.0001, 'RUB', 1, 'mixed', 0, 15, 0, NULL, UTC_TIMESTAMP(3)),
|
||||
('PENDING:TGLD', 'TGLD', 'TQTF', 'TGLD', 1, 0.0001, 'RUB', 1, 'commodity', 0, 15, 0, NULL, UTC_TIMESTAMP(3)),
|
||||
('PENDING:TBRU', 'TBRU', 'TQTF', 'TBRU', 1, 0.0001, 'RUB', 1, 'bonds', 0, 0, 0, NULL, UTC_TIMESTAMP(3)),
|
||||
('PENDING:TDIV', 'TDIV', 'TQTF', 'TDIV', 1, 0.0001, 'RUB', 1, 'equity_income', 0, 0, 0, NULL, UTC_TIMESTAMP(3)),
|
||||
('PENDING:TMON', 'TMON', 'TQTF', 'TMON', 1, 0.0001, 'RUB', 1, 'money_market', 0, 0, 0, NULL, UTC_TIMESTAMP(3)),
|
||||
('PENDING:TOFZ', 'TOFZ', 'TQTF', 'TOFZ', 1, 0.0001, 'RUB', 1, 'bonds', 0, 0, 0, NULL, UTC_TIMESTAMP(3)),
|
||||
('PENDING:TLCB', 'TLCB', 'TQTF', 'TLCB', 1, 0.0001, 'RUB', 1, 'corporate_bonds', 0, 0, 0, NULL, UTC_TIMESTAMP(3)),
|
||||
('PENDING:TITR', 'TITR', 'TQTF', 'TITR', 1, 0.0001, 'RUB', 1, 'equity', 0, 0, 0, NULL, UTC_TIMESTAMP(3)),
|
||||
('PENDING:TRND', 'TRND', 'TQTF', 'TRND', 1, 0.0001, 'RUB', 1, 'equity', 0, 0, 0, NULL, UTC_TIMESTAMP(3)),
|
||||
('PENDING:TMOS', 'TMOS', 'TQTF', 'TMOS', 1, 0.0001, 'RUB', 0, 'equity', 0, 0, 0, 'Excluded by default due to possible non-zero sell-side fee', UTC_TIMESTAMP(3))
|
||||
ON DUPLICATE KEY UPDATE
|
||||
enabled=VALUES(enabled),
|
||||
fund_type=VALUES(fund_type),
|
||||
expected_commission_bps_per_side=VALUES(expected_commission_bps_per_side),
|
||||
free_order_limit_per_day=VALUES(free_order_limit_per_day),
|
||||
exclude_reason=VALUES(exclude_reason),
|
||||
updated_at=UTC_TIMESTAMP(3);
|
||||
|
||||
UPDATE schema_meta SET meta_value='0002' WHERE meta_key='schema_version';
|
||||
@@ -0,0 +1,29 @@
|
||||
ALTER TABLE free_order_counters DROP FOREIGN KEY fk_free_orders_instrument;
|
||||
ALTER TABLE free_order_counters ADD CONSTRAINT fk_free_orders_instrument
|
||||
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid);
|
||||
|
||||
ALTER TABLE positions DROP FOREIGN KEY fk_positions_instrument;
|
||||
ALTER TABLE positions ADD CONSTRAINT fk_positions_instrument
|
||||
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid);
|
||||
|
||||
ALTER TABLE orders DROP FOREIGN KEY fk_orders_instrument;
|
||||
ALTER TABLE orders ADD CONSTRAINT fk_orders_instrument
|
||||
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid);
|
||||
|
||||
ALTER TABLE signals DROP FOREIGN KEY fk_signals_instrument;
|
||||
ALTER TABLE signals ADD CONSTRAINT fk_signals_instrument
|
||||
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid);
|
||||
|
||||
ALTER TABLE features DROP FOREIGN KEY fk_features_instrument;
|
||||
ALTER TABLE features ADD CONSTRAINT fk_features_instrument
|
||||
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid);
|
||||
|
||||
ALTER TABLE candles_minute DROP FOREIGN KEY fk_candles_minute_instrument;
|
||||
ALTER TABLE candles_minute ADD CONSTRAINT fk_candles_minute_instrument
|
||||
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid);
|
||||
|
||||
ALTER TABLE candles_daily DROP FOREIGN KEY fk_candles_daily_instrument;
|
||||
ALTER TABLE candles_daily ADD CONSTRAINT fk_candles_daily_instrument
|
||||
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid);
|
||||
|
||||
UPDATE schema_meta SET meta_value='0002' WHERE meta_key='schema_version';
|
||||
@@ -0,0 +1,29 @@
|
||||
ALTER TABLE candles_daily DROP FOREIGN KEY fk_candles_daily_instrument;
|
||||
ALTER TABLE candles_daily ADD CONSTRAINT fk_candles_daily_instrument
|
||||
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE;
|
||||
|
||||
ALTER TABLE candles_minute DROP FOREIGN KEY fk_candles_minute_instrument;
|
||||
ALTER TABLE candles_minute ADD CONSTRAINT fk_candles_minute_instrument
|
||||
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE;
|
||||
|
||||
ALTER TABLE features DROP FOREIGN KEY fk_features_instrument;
|
||||
ALTER TABLE features ADD CONSTRAINT fk_features_instrument
|
||||
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE;
|
||||
|
||||
ALTER TABLE signals DROP FOREIGN KEY fk_signals_instrument;
|
||||
ALTER TABLE signals ADD CONSTRAINT fk_signals_instrument
|
||||
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE;
|
||||
|
||||
ALTER TABLE orders DROP FOREIGN KEY fk_orders_instrument;
|
||||
ALTER TABLE orders ADD CONSTRAINT fk_orders_instrument
|
||||
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE;
|
||||
|
||||
ALTER TABLE positions DROP FOREIGN KEY fk_positions_instrument;
|
||||
ALTER TABLE positions ADD CONSTRAINT fk_positions_instrument
|
||||
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE;
|
||||
|
||||
ALTER TABLE free_order_counters DROP FOREIGN KEY fk_free_orders_instrument;
|
||||
ALTER TABLE free_order_counters ADD CONSTRAINT fk_free_orders_instrument
|
||||
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE;
|
||||
|
||||
UPDATE schema_meta SET meta_value='0003' WHERE meta_key='schema_version';
|
||||
@@ -0,0 +1,5 @@
|
||||
DROP TABLE IF EXISTS daily_reports;
|
||||
ALTER TABLE positions DROP INDEX ux_positions_trade;
|
||||
ALTER TABLE positions DROP COLUMN exit_filled_lots;
|
||||
|
||||
UPDATE schema_meta SET meta_value='0003' WHERE meta_key='schema_version';
|
||||
@@ -0,0 +1,11 @@
|
||||
ALTER TABLE positions ADD COLUMN exit_filled_lots BIGINT NOT NULL DEFAULT 0 AFTER lots;
|
||||
ALTER TABLE positions ADD UNIQUE KEY ux_positions_trade (account_id_hash, instrument_uid, open_trade_date);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS daily_reports (
|
||||
report_date DATE NOT NULL,
|
||||
account_id_hash VARCHAR(128) NOT NULL,
|
||||
sent_at DATETIME(3) NOT NULL,
|
||||
PRIMARY KEY (report_date, account_id_hash)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
UPDATE schema_meta SET meta_value='0004' WHERE meta_key='schema_version';
|
||||
@@ -0,0 +1,7 @@
|
||||
ALTER TABLE risk_events
|
||||
MODIFY severity ENUM('INFO','WARN','ALERT','CRITICAL','REPORT') NOT NULL;
|
||||
|
||||
ALTER TABLE instruments
|
||||
MODIFY free_order_limit_per_day INT NOT NULL DEFAULT 0;
|
||||
|
||||
UPDATE schema_meta SET meta_value='0004' WHERE meta_key='schema_version';
|
||||
@@ -0,0 +1,20 @@
|
||||
UPDATE instruments
|
||||
SET free_order_limit_per_day=0
|
||||
WHERE ticker NOT IN ('TRUR', 'TGLD') AND free_order_limit_per_day=15;
|
||||
|
||||
ALTER TABLE instruments
|
||||
MODIFY free_order_limit_per_day INT NOT NULL DEFAULT 0 COMMENT '0 means no configured free-order cap';
|
||||
|
||||
UPDATE risk_events
|
||||
SET
|
||||
severity='INFO',
|
||||
event_type=CASE
|
||||
WHEN event_type LIKE 'report_%' THEN event_type
|
||||
ELSE CONCAT('report_', event_type)
|
||||
END
|
||||
WHERE severity='REPORT';
|
||||
|
||||
ALTER TABLE risk_events
|
||||
MODIFY severity ENUM('INFO','WARN','ALERT','CRITICAL') NOT NULL;
|
||||
|
||||
UPDATE schema_meta SET meta_value='0005' WHERE meta_key='schema_version';
|
||||
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE positions DROP COLUMN lot_size;
|
||||
|
||||
UPDATE schema_meta SET meta_value='0005' WHERE meta_key='schema_version';
|
||||
@@ -0,0 +1,8 @@
|
||||
ALTER TABLE positions ADD COLUMN lot_size BIGINT NOT NULL DEFAULT 1 AFTER lots;
|
||||
|
||||
UPDATE positions p
|
||||
JOIN instruments i ON i.instrument_uid = p.instrument_uid
|
||||
SET p.lot_size = i.lot
|
||||
WHERE p.lot_size = 1 AND i.lot > 1;
|
||||
|
||||
UPDATE schema_meta SET meta_value='0006' WHERE meta_key='schema_version';
|
||||
@@ -0,0 +1,8 @@
|
||||
package migrations
|
||||
|
||||
import "embed"
|
||||
|
||||
// FS contains SQL migrations used by both the daemon and cmd/migrate.
|
||||
//
|
||||
//go:embed *.sql
|
||||
var FS embed.FS
|
||||
@@ -0,0 +1,61 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
migratemysql "github.com/golang-migrate/migrate/v4/database/mysql"
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
|
||||
"overnight-trading-bot/internal/repository/migrations"
|
||||
)
|
||||
|
||||
func ApplyMigrations(ctx context.Context, db *sql.DB) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
driver, err := migratemysql.WithInstance(db, &migratemysql.Config{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create mysql migration driver: %w", err)
|
||||
}
|
||||
source, err := iofs.New(migrations.FS, ".")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create iofs migration source: %w", err)
|
||||
}
|
||||
m, err := migrate.NewWithInstance("iofs", source, "mysql", driver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create migrate instance: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_, _ = m.Close()
|
||||
}()
|
||||
if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||
return fmt.Errorf("apply migrations: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RollbackAll(db *sql.DB) error {
|
||||
driver, err := migratemysql.WithInstance(db, &migratemysql.Config{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create mysql migration driver: %w", err)
|
||||
}
|
||||
source, err := iofs.New(migrations.FS, ".")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create iofs migration source: %w", err)
|
||||
}
|
||||
m, err := migrate.NewWithInstance("iofs", source, "mysql", driver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create migrate instance: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_, _ = m.Close()
|
||||
}()
|
||||
if err := m.Down(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||
return fmt.Errorf("rollback migrations: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,790 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/repository"
|
||||
)
|
||||
|
||||
var _ repository.Repository = (*Repository)(nil)
|
||||
|
||||
type Repository struct {
|
||||
db *sqlx.DB
|
||||
tx *sqlx.Tx
|
||||
}
|
||||
|
||||
func NewRepository(db *sqlx.DB) *Repository {
|
||||
return &Repository{db: db}
|
||||
}
|
||||
|
||||
func (r *Repository) RunInTx(ctx context.Context, fn func(ctx context.Context, repo repository.Repository) error) error {
|
||||
if r.tx != nil {
|
||||
return fn(ctx, r)
|
||||
}
|
||||
tx, err := r.db.BeginTxx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
txRepo := &Repository{db: r.db, tx: tx}
|
||||
if err := fn(ctx, txRepo); err != nil {
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
return fmt.Errorf("%w; rollback: %v", err, rbErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *Repository) execer() sqlx.ExtContext {
|
||||
if r.tx != nil {
|
||||
return r.tx
|
||||
}
|
||||
return r.db
|
||||
}
|
||||
|
||||
func (r *Repository) selectContext(ctx context.Context, dest any, query string, args ...any) error {
|
||||
if r.tx != nil {
|
||||
return r.tx.SelectContext(ctx, dest, query, args...)
|
||||
}
|
||||
return r.db.SelectContext(ctx, dest, query, args...)
|
||||
}
|
||||
|
||||
func (r *Repository) getContext(ctx context.Context, dest any, query string, args ...any) error {
|
||||
if r.tx != nil {
|
||||
return r.tx.GetContext(ctx, dest, query, args...)
|
||||
}
|
||||
return r.db.GetContext(ctx, dest, query, args...)
|
||||
}
|
||||
|
||||
func (r *Repository) UpsertInstrument(ctx context.Context, instrument domain.Instrument) error {
|
||||
if instrument.UpdatedAt.IsZero() {
|
||||
instrument.UpdatedAt = time.Now().UTC()
|
||||
}
|
||||
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
INSERT INTO instruments (
|
||||
instrument_uid, figi, ticker, class_code, name, lot, min_price_increment, currency,
|
||||
enabled, fund_type, expected_commission_bps_per_side, free_order_limit_per_day,
|
||||
quarantine, quarantine_reason, exclude_reason, updated_at
|
||||
) VALUES (
|
||||
:instrument_uid, :figi, :ticker, :class_code, :name, :lot, :min_price_increment, :currency,
|
||||
:enabled, :fund_type, :expected_commission_bps_per_side, :free_order_limit_per_day,
|
||||
:quarantine, :quarantine_reason, :exclude_reason, :updated_at
|
||||
) ON DUPLICATE KEY UPDATE
|
||||
instrument_uid=VALUES(instrument_uid),
|
||||
figi=VALUES(figi),
|
||||
name=VALUES(name),
|
||||
lot=VALUES(lot),
|
||||
min_price_increment=VALUES(min_price_increment),
|
||||
currency=VALUES(currency),
|
||||
enabled=VALUES(enabled),
|
||||
fund_type=VALUES(fund_type),
|
||||
expected_commission_bps_per_side=VALUES(expected_commission_bps_per_side),
|
||||
free_order_limit_per_day=VALUES(free_order_limit_per_day),
|
||||
quarantine=VALUES(quarantine),
|
||||
quarantine_reason=VALUES(quarantine_reason),
|
||||
exclude_reason=VALUES(exclude_reason),
|
||||
updated_at=VALUES(updated_at)`, instrumentRowFromDomain(instrument))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) ReplaceInstrument(ctx context.Context, oldInstrumentUID string, instrument domain.Instrument) error {
|
||||
if oldInstrumentUID == "" || oldInstrumentUID == instrument.InstrumentUID {
|
||||
return r.UpsertInstrument(ctx, instrument)
|
||||
}
|
||||
return r.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error {
|
||||
txRepo, ok := repo.(*Repository)
|
||||
if !ok {
|
||||
return errors.New("unexpected repository implementation")
|
||||
}
|
||||
return txRepo.replaceInstrument(ctx, oldInstrumentUID, instrument)
|
||||
})
|
||||
}
|
||||
|
||||
func (r *Repository) replaceInstrument(ctx context.Context, oldInstrumentUID string, instrument domain.Instrument) error {
|
||||
if instrument.UpdatedAt.IsZero() {
|
||||
instrument.UpdatedAt = time.Now().UTC()
|
||||
}
|
||||
exists, err := r.instrumentExists(ctx, instrument.InstrumentUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
if err := r.mergeInstrumentUID(ctx, oldInstrumentUID, instrument.InstrumentUID); err != nil {
|
||||
return err
|
||||
}
|
||||
return r.UpsertInstrument(ctx, instrument)
|
||||
}
|
||||
result, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
UPDATE instruments SET
|
||||
instrument_uid=:instrument_uid,
|
||||
figi=:figi,
|
||||
ticker=:ticker,
|
||||
class_code=:class_code,
|
||||
name=:name,
|
||||
lot=:lot,
|
||||
min_price_increment=:min_price_increment,
|
||||
currency=:currency,
|
||||
enabled=:enabled,
|
||||
fund_type=:fund_type,
|
||||
expected_commission_bps_per_side=:expected_commission_bps_per_side,
|
||||
free_order_limit_per_day=:free_order_limit_per_day,
|
||||
quarantine=:quarantine,
|
||||
quarantine_reason=:quarantine_reason,
|
||||
exclude_reason=:exclude_reason,
|
||||
updated_at=:updated_at
|
||||
WHERE instrument_uid=:old_instrument_uid`, replaceInstrumentRowFromDomain(oldInstrumentUID, instrument))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return r.UpsertInstrument(ctx, instrument)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Repository) instrumentExists(ctx context.Context, instrumentUID string) (bool, error) {
|
||||
var count int
|
||||
if err := r.getContext(ctx, &count, `SELECT COUNT(*) FROM instruments WHERE instrument_uid=?`, instrumentUID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *Repository) mergeInstrumentUID(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
|
||||
if oldInstrumentUID == newInstrumentUID {
|
||||
return nil
|
||||
}
|
||||
if err := r.mergeDailyCandles(ctx, oldInstrumentUID, newInstrumentUID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.mergeMinuteCandles(ctx, oldInstrumentUID, newInstrumentUID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.mergeFeatures(ctx, oldInstrumentUID, newInstrumentUID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.mergeSignals(ctx, oldInstrumentUID, newInstrumentUID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.mergeFreeOrders(ctx, oldInstrumentUID, newInstrumentUID); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, table := range []string{"orders", "positions", "risk_events"} {
|
||||
if _, err := r.execer().ExecContext(ctx, fmt.Sprintf(`UPDATE %s SET instrument_uid=? WHERE instrument_uid=?`, table), newInstrumentUID, oldInstrumentUID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := r.execer().ExecContext(ctx, `DELETE FROM instruments WHERE instrument_uid=?`, oldInstrumentUID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) mergeDailyCandles(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO candles_daily (instrument_uid, trade_date, open, high, low, close, volume_lots, source, loaded_at)
|
||||
SELECT ?, trade_date, open, high, low, close, volume_lots, source, loaded_at
|
||||
FROM candles_daily WHERE instrument_uid=?
|
||||
ON DUPLICATE KEY UPDATE
|
||||
open=VALUES(open), high=VALUES(high), low=VALUES(low), close=VALUES(close),
|
||||
volume_lots=VALUES(volume_lots), source=VALUES(source), loaded_at=VALUES(loaded_at)`, newInstrumentUID, oldInstrumentUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.execer().ExecContext(ctx, `DELETE FROM candles_daily WHERE instrument_uid=?`, oldInstrumentUID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) mergeMinuteCandles(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO candles_minute (instrument_uid, ts, open, high, low, close, volume_lots, source, loaded_at)
|
||||
SELECT ?, ts, open, high, low, close, volume_lots, source, loaded_at
|
||||
FROM candles_minute WHERE instrument_uid=?
|
||||
ON DUPLICATE KEY UPDATE
|
||||
open=VALUES(open), high=VALUES(high), low=VALUES(low), close=VALUES(close),
|
||||
volume_lots=VALUES(volume_lots), source=VALUES(source), loaded_at=VALUES(loaded_at)`, newInstrumentUID, oldInstrumentUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.execer().ExecContext(ctx, `DELETE FROM candles_minute WHERE instrument_uid=?`, oldInstrumentUID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) mergeFeatures(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO features (
|
||||
instrument_uid, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60,
|
||||
tstat_on_60, win_on_60, ewma_on, spread_bps, half_spread_bps, tick_bps,
|
||||
adv_20, expected_cost_bps, net_edge_bps, entry_interval_volume,
|
||||
exit_interval_volume, calculated_at
|
||||
)
|
||||
SELECT
|
||||
?, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60,
|
||||
tstat_on_60, win_on_60, ewma_on, spread_bps, half_spread_bps, tick_bps,
|
||||
adv_20, expected_cost_bps, net_edge_bps, entry_interval_volume,
|
||||
exit_interval_volume, calculated_at
|
||||
FROM features WHERE instrument_uid=?
|
||||
ON DUPLICATE KEY UPDATE
|
||||
r_on=VALUES(r_on), r_day=VALUES(r_day), mu_on_60=VALUES(mu_on_60),
|
||||
mu_on_252=VALUES(mu_on_252), sigma_on_60=VALUES(sigma_on_60),
|
||||
tstat_on_60=VALUES(tstat_on_60), win_on_60=VALUES(win_on_60),
|
||||
ewma_on=VALUES(ewma_on), spread_bps=VALUES(spread_bps),
|
||||
half_spread_bps=VALUES(half_spread_bps), tick_bps=VALUES(tick_bps),
|
||||
adv_20=VALUES(adv_20), expected_cost_bps=VALUES(expected_cost_bps),
|
||||
net_edge_bps=VALUES(net_edge_bps), entry_interval_volume=VALUES(entry_interval_volume),
|
||||
exit_interval_volume=VALUES(exit_interval_volume), calculated_at=VALUES(calculated_at)`, newInstrumentUID, oldInstrumentUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.execer().ExecContext(ctx, `DELETE FROM features WHERE instrument_uid=?`, oldInstrumentUID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) mergeSignals(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO signals (
|
||||
trade_date, instrument_uid, decision, score, net_edge_bps, target_notional,
|
||||
target_lots, reject_reason, context_json, created_at
|
||||
)
|
||||
SELECT trade_date, ?, decision, score, net_edge_bps, target_notional,
|
||||
target_lots, reject_reason, context_json, created_at
|
||||
FROM signals WHERE instrument_uid=?
|
||||
ON DUPLICATE KEY UPDATE
|
||||
decision=VALUES(decision), score=VALUES(score), net_edge_bps=VALUES(net_edge_bps),
|
||||
target_notional=VALUES(target_notional), target_lots=VALUES(target_lots),
|
||||
reject_reason=VALUES(reject_reason), context_json=VALUES(context_json),
|
||||
created_at=VALUES(created_at)`, newInstrumentUID, oldInstrumentUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.execer().ExecContext(ctx, `DELETE FROM signals WHERE instrument_uid=?`, oldInstrumentUID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) mergeFreeOrders(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO free_order_counters (trade_date, instrument_uid, orders_sent)
|
||||
SELECT trade_date, ?, orders_sent FROM free_order_counters WHERE instrument_uid=?
|
||||
ON DUPLICATE KEY UPDATE orders_sent=GREATEST(orders_sent, VALUES(orders_sent))`, newInstrumentUID, oldInstrumentUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.execer().ExecContext(ctx, `DELETE FROM free_order_counters WHERE instrument_uid=?`, oldInstrumentUID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) ListInstruments(ctx context.Context, includeDisabled bool) ([]domain.Instrument, error) {
|
||||
query := `SELECT * FROM instruments`
|
||||
if !includeDisabled {
|
||||
query += ` WHERE enabled=1`
|
||||
}
|
||||
query += ` ORDER BY ticker`
|
||||
var rows []instrumentRow
|
||||
if err := r.selectContext(ctx, &rows, query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]domain.Instrument, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, row.domain())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *Repository) QuarantineInstrument(ctx context.Context, instrumentUID, reason string) error {
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
UPDATE instruments SET quarantine=1, quarantine_reason=?, updated_at=UTC_TIMESTAMP(3)
|
||||
WHERE instrument_uid=?`, reason, instrumentUID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) UpsertDailyCandles(ctx context.Context, candles []domain.Candle) error {
|
||||
for _, candle := range candles {
|
||||
if candle.LoadedAt.IsZero() {
|
||||
candle.LoadedAt = time.Now().UTC()
|
||||
}
|
||||
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
INSERT INTO candles_daily (
|
||||
instrument_uid, trade_date, open, high, low, close, volume_lots, source, loaded_at
|
||||
) VALUES (
|
||||
:instrument_uid, :trade_date, :open, :high, :low, :close, :volume_lots, :source, :loaded_at
|
||||
) ON DUPLICATE KEY UPDATE
|
||||
open=VALUES(open), high=VALUES(high), low=VALUES(low), close=VALUES(close),
|
||||
volume_lots=VALUES(volume_lots), source=VALUES(source), loaded_at=VALUES(loaded_at)`, candleRowFromDomain(candle))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Repository) ListDailyCandles(ctx context.Context, instrumentUID string, from, to time.Time) ([]domain.Candle, error) {
|
||||
var rows []candleRow
|
||||
if err := r.selectContext(ctx, &rows, `
|
||||
SELECT * FROM candles_daily
|
||||
WHERE instrument_uid=? AND trade_date BETWEEN ? AND ?
|
||||
ORDER BY trade_date`, instrumentUID, dateOnly(from), dateOnly(to)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]domain.Candle, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, row.domain())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *Repository) UpsertMinuteCandles(ctx context.Context, candles []domain.Candle) error {
|
||||
for _, candle := range candles {
|
||||
if candle.LoadedAt.IsZero() {
|
||||
candle.LoadedAt = time.Now().UTC()
|
||||
}
|
||||
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
INSERT INTO candles_minute (
|
||||
instrument_uid, ts, open, high, low, close, volume_lots, source, loaded_at
|
||||
) VALUES (
|
||||
:instrument_uid, :trade_date, :open, :high, :low, :close, :volume_lots, :source, :loaded_at
|
||||
) ON DUPLICATE KEY UPDATE
|
||||
open=VALUES(open), high=VALUES(high), low=VALUES(low), close=VALUES(close),
|
||||
volume_lots=VALUES(volume_lots), source=VALUES(source), loaded_at=VALUES(loaded_at)`, candleRowFromDomain(candle))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Repository) ListMinuteCandles(ctx context.Context, instrumentUID string, from, to time.Time) ([]domain.Candle, error) {
|
||||
var rows []candleRow
|
||||
if err := r.selectContext(ctx, &rows, `
|
||||
SELECT instrument_uid, ts AS trade_date, open, high, low, close, volume_lots, source, loaded_at
|
||||
FROM candles_minute
|
||||
WHERE instrument_uid=? AND ts BETWEEN ? AND ?
|
||||
ORDER BY ts`, instrumentUID, from.UTC(), to.UTC()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]domain.Candle, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, row.domain())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *Repository) UpsertFeature(ctx context.Context, feature domain.FeatureSet) error {
|
||||
if feature.CalculatedAt.IsZero() {
|
||||
feature.CalculatedAt = time.Now().UTC()
|
||||
}
|
||||
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
INSERT INTO features (
|
||||
instrument_uid, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60,
|
||||
tstat_on_60, win_on_60, ewma_on, spread_bps, half_spread_bps, tick_bps,
|
||||
adv_20, expected_cost_bps, net_edge_bps, entry_interval_volume,
|
||||
exit_interval_volume, calculated_at
|
||||
) VALUES (
|
||||
:instrument_uid, :trade_date, :r_on, :r_day, :mu_on_60, :mu_on_252, :sigma_on_60,
|
||||
:tstat_on_60, :win_on_60, :ewma_on, :spread_bps, :half_spread_bps, :tick_bps,
|
||||
:adv_20, :expected_cost_bps, :net_edge_bps, :entry_interval_volume,
|
||||
:exit_interval_volume, :calculated_at
|
||||
) ON DUPLICATE KEY UPDATE
|
||||
r_on=VALUES(r_on), r_day=VALUES(r_day), mu_on_60=VALUES(mu_on_60),
|
||||
mu_on_252=VALUES(mu_on_252), sigma_on_60=VALUES(sigma_on_60),
|
||||
tstat_on_60=VALUES(tstat_on_60), win_on_60=VALUES(win_on_60),
|
||||
ewma_on=VALUES(ewma_on), spread_bps=VALUES(spread_bps),
|
||||
half_spread_bps=VALUES(half_spread_bps), tick_bps=VALUES(tick_bps),
|
||||
adv_20=VALUES(adv_20), expected_cost_bps=VALUES(expected_cost_bps),
|
||||
net_edge_bps=VALUES(net_edge_bps), entry_interval_volume=VALUES(entry_interval_volume),
|
||||
exit_interval_volume=VALUES(exit_interval_volume), calculated_at=VALUES(calculated_at)`, featureRowFromDomain(feature))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) GetFeature(ctx context.Context, instrumentUID string, tradeDate time.Time) (domain.FeatureSet, error) {
|
||||
var row featureRow
|
||||
if err := r.getContext(ctx, &row, `SELECT * FROM features WHERE instrument_uid=? AND trade_date=?`, instrumentUID, dateOnly(tradeDate)); err != nil {
|
||||
return domain.FeatureSet{}, err
|
||||
}
|
||||
return row.domain(), nil
|
||||
}
|
||||
|
||||
func (r *Repository) UpsertSignal(ctx context.Context, signal domain.Signal) error {
|
||||
if signal.CreatedAt.IsZero() {
|
||||
signal.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
if signal.ContextJSON == "" {
|
||||
signal.ContextJSON = "{}"
|
||||
}
|
||||
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
INSERT INTO signals (
|
||||
trade_date, instrument_uid, decision, score, net_edge_bps, target_notional,
|
||||
target_lots, reject_reason, context_json, created_at
|
||||
) VALUES (
|
||||
:trade_date, :instrument_uid, :decision, :score, :net_edge_bps, :target_notional,
|
||||
:target_lots, :reject_reason, :context_json, :created_at
|
||||
) ON DUPLICATE KEY UPDATE
|
||||
decision=VALUES(decision), score=VALUES(score), net_edge_bps=VALUES(net_edge_bps),
|
||||
target_notional=VALUES(target_notional), target_lots=VALUES(target_lots),
|
||||
reject_reason=VALUES(reject_reason), context_json=VALUES(context_json),
|
||||
created_at=VALUES(created_at)`, signalRowFromDomain(signal))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) ListSignals(ctx context.Context, tradeDate time.Time) ([]domain.Signal, error) {
|
||||
var rows []signalRow
|
||||
if err := r.selectContext(ctx, &rows, `SELECT * FROM signals WHERE trade_date=? ORDER BY id`, dateOnly(tradeDate)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]domain.Signal, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, row.domain())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *Repository) UpsertOrder(ctx context.Context, order domain.Order) error {
|
||||
now := time.Now().UTC()
|
||||
if order.CreatedAt.IsZero() {
|
||||
order.CreatedAt = now
|
||||
}
|
||||
order.UpdatedAt = now
|
||||
if order.RawStateJSON == "" {
|
||||
order.RawStateJSON = "{}"
|
||||
}
|
||||
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
INSERT INTO orders (
|
||||
client_order_id, broker_order_id, account_id_hash, instrument_uid, trade_date,
|
||||
side, order_type, limit_price, quantity_lots, filled_lots, avg_fill_price,
|
||||
status, commission, attempt_no, raw_state_json, created_at, updated_at
|
||||
) VALUES (
|
||||
:client_order_id, :broker_order_id, :account_id_hash, :instrument_uid, :trade_date,
|
||||
:side, :order_type, :limit_price, :quantity_lots, :filled_lots, :avg_fill_price,
|
||||
:status, :commission, :attempt_no, :raw_state_json, :created_at, :updated_at
|
||||
) ON DUPLICATE KEY UPDATE
|
||||
broker_order_id=VALUES(broker_order_id), filled_lots=VALUES(filled_lots),
|
||||
avg_fill_price=VALUES(avg_fill_price), status=VALUES(status),
|
||||
commission=VALUES(commission), raw_state_json=VALUES(raw_state_json),
|
||||
updated_at=VALUES(updated_at)`, orderRowFromDomain(order))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) UpdateOrderStatus(ctx context.Context, clientOrderID string, status domain.OrderStatus, filledLots int64, rawJSON string) error {
|
||||
if rawJSON == "" {
|
||||
rawJSON = "{}"
|
||||
}
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
UPDATE orders SET status=?, filled_lots=?, raw_state_json=?, updated_at=UTC_TIMESTAMP(3)
|
||||
WHERE client_order_id=?`, status, filledLots, rawJSON, clientOrderID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) ListActiveOrders(ctx context.Context, accountIDHash string) ([]domain.Order, error) {
|
||||
var rows []orderRow
|
||||
if err := r.selectContext(ctx, &rows, `
|
||||
SELECT * FROM orders
|
||||
WHERE account_id_hash=? AND status IN ('NEW','SENT','PARTIALLY_FILLED')
|
||||
ORDER BY created_at`, accountIDHash); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]domain.Order, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, row.domain())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *Repository) ListOrders(ctx context.Context, accountIDHash string, from, to time.Time) ([]domain.Order, error) {
|
||||
var rows []orderRow
|
||||
if err := r.selectContext(ctx, &rows, `
|
||||
SELECT * FROM orders
|
||||
WHERE account_id_hash=? AND trade_date BETWEEN ? AND ?
|
||||
ORDER BY created_at`, accountIDHash, dateOnly(from), dateOnly(to)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]domain.Order, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, row.domain())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *Repository) UpsertPosition(ctx context.Context, position domain.Position) error {
|
||||
if position.UpdatedAt.IsZero() {
|
||||
position.UpdatedAt = time.Now().UTC()
|
||||
}
|
||||
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
INSERT INTO positions (
|
||||
id, account_id_hash, instrument_uid, open_trade_date, lots, lot_size, exit_filled_lots,
|
||||
avg_buy_price, avg_sell_price, status, gross_pnl, net_pnl, commission_total,
|
||||
realized_edge_bps, opened_at, closed_at, updated_at
|
||||
) VALUES (
|
||||
NULLIF(:id, 0), :account_id_hash, :instrument_uid, :open_trade_date, :lots, :lot_size, :exit_filled_lots,
|
||||
:avg_buy_price, :avg_sell_price, :status, :gross_pnl, :net_pnl, :commission_total,
|
||||
:realized_edge_bps, :opened_at, :closed_at, :updated_at
|
||||
) ON DUPLICATE KEY UPDATE
|
||||
lots=VALUES(lots), lot_size=VALUES(lot_size), exit_filled_lots=VALUES(exit_filled_lots), avg_buy_price=VALUES(avg_buy_price), avg_sell_price=VALUES(avg_sell_price),
|
||||
status=VALUES(status), gross_pnl=VALUES(gross_pnl), net_pnl=VALUES(net_pnl),
|
||||
commission_total=VALUES(commission_total), realized_edge_bps=VALUES(realized_edge_bps),
|
||||
opened_at=VALUES(opened_at), closed_at=VALUES(closed_at), updated_at=VALUES(updated_at)`, positionRowFromDomain(position))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) ListOpenPositions(ctx context.Context, accountIDHash string) ([]domain.Position, error) {
|
||||
var rows []positionRow
|
||||
if err := r.selectContext(ctx, &rows, `
|
||||
SELECT * FROM positions
|
||||
WHERE account_id_hash=? AND status NOT IN ('NO_POSITION','EXIT_FILLED','QUARANTINE')
|
||||
ORDER BY updated_at`, accountIDHash); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]domain.Position, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, row.domain())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *Repository) ListPositions(ctx context.Context, accountIDHash string, from, to time.Time) ([]domain.Position, error) {
|
||||
var rows []positionRow
|
||||
if err := r.selectContext(ctx, &rows, `
|
||||
SELECT * FROM positions
|
||||
WHERE account_id_hash=? AND open_trade_date BETWEEN ? AND ?
|
||||
ORDER BY updated_at`, accountIDHash, dateOnly(from), dateOnly(to)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]domain.Position, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, row.domain())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *Repository) InsertRiskEvent(ctx context.Context, event domain.RiskEvent) error {
|
||||
if event.TS.IsZero() {
|
||||
event.TS = time.Now().UTC()
|
||||
}
|
||||
if event.ContextJSON == "" {
|
||||
event.ContextJSON = "{}"
|
||||
}
|
||||
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
INSERT INTO risk_events (ts, severity, event_type, instrument_uid, message, raw_context_json)
|
||||
VALUES (:ts, :severity, :event_type, :instrument_uid, :message, :raw_context_json)`, riskEventRowFromDomain(event))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) GetFreeOrdersSent(ctx context.Context, tradeDate time.Time, instrumentUID string) (int, error) {
|
||||
var sent int
|
||||
err := r.getContext(ctx, &sent, `
|
||||
SELECT orders_sent FROM free_order_counters WHERE trade_date=? AND instrument_uid=?`, dateOnly(tradeDate), instrumentUID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, nil
|
||||
}
|
||||
return sent, err
|
||||
}
|
||||
|
||||
func (r *Repository) IncrementFreeOrders(ctx context.Context, tradeDate time.Time, instrumentUID string, delta int) error {
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO free_order_counters (trade_date, instrument_uid, orders_sent)
|
||||
VALUES (?, ?, ?)
|
||||
ON DUPLICATE KEY UPDATE orders_sent=orders_sent+VALUES(orders_sent)`, dateOnly(tradeDate), instrumentUID, delta)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) GetSystemState(ctx context.Context) (domain.SystemState, bool, string, error) {
|
||||
var row struct {
|
||||
State string `db:"state"`
|
||||
Halted bool `db:"halted"`
|
||||
HaltReason sql.NullString `db:"halt_reason"`
|
||||
}
|
||||
if err := r.getContext(ctx, &row, `SELECT state, halted, halt_reason FROM system_state WHERE id=1`); err != nil {
|
||||
return "", false, "", err
|
||||
}
|
||||
return domain.SystemState(row.State), row.Halted, row.HaltReason.String, nil
|
||||
}
|
||||
|
||||
func (r *Repository) SaveSystemState(ctx context.Context, state domain.SystemState, mode domain.Mode, halted bool, reason string, contextJSON string) error {
|
||||
if contextJSON == "" {
|
||||
contextJSON = "{}"
|
||||
}
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO system_state (id, state, mode, halted, halt_reason, last_heartbeat, context_json)
|
||||
VALUES (1, ?, ?, ?, ?, UTC_TIMESTAMP(3), ?)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
state=VALUES(state), mode=VALUES(mode), halted=VALUES(halted),
|
||||
halt_reason=VALUES(halt_reason), last_heartbeat=VALUES(last_heartbeat),
|
||||
context_json=VALUES(context_json)`, state, mode, halted, nullableString(reason), contextJSON)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) Unhalt(ctx context.Context, reason string) error {
|
||||
return r.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error {
|
||||
state, halted, haltReason, err := repo.GetSystemState(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !halted && state != domain.StateHalted {
|
||||
return fmt.Errorf("system is not halted")
|
||||
}
|
||||
if err := repo.InsertRiskEvent(ctx, domain.RiskEvent{
|
||||
TS: time.Now().UTC(),
|
||||
Severity: domain.SeverityInfo,
|
||||
EventType: "manual_unhalt",
|
||||
Message: fmt.Sprintf("%s (previous halt: %s)", reason, haltReason),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
mode := domain.ModePaper
|
||||
if txRepo, ok := repo.(*Repository); ok {
|
||||
currentMode, err := txRepo.getSystemMode(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mode = currentMode
|
||||
}
|
||||
return repo.SaveSystemState(ctx, domain.StateInit, mode, false, "", `{"manual_unhalt":true}`)
|
||||
})
|
||||
}
|
||||
|
||||
func (r *Repository) getSystemMode(ctx context.Context) (domain.Mode, error) {
|
||||
var raw string
|
||||
if err := r.getContext(ctx, &raw, `SELECT mode FROM system_state WHERE id=1`); err != nil {
|
||||
return "", err
|
||||
}
|
||||
mode, err := domain.ParseMode(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return mode, nil
|
||||
}
|
||||
|
||||
func (r *Repository) WasDailyReportSent(ctx context.Context, reportDate time.Time, accountIDHash string) (bool, error) {
|
||||
var count int
|
||||
if err := r.getContext(ctx, &count, `
|
||||
SELECT COUNT(*) FROM daily_reports WHERE report_date=? AND account_id_hash=?`, dateOnly(reportDate), accountIDHash); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *Repository) MarkDailyReportSent(ctx context.Context, reportDate time.Time, accountIDHash string) error {
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO daily_reports (report_date, account_id_hash, sent_at)
|
||||
VALUES (?, ?, UTC_TIMESTAMP(3))
|
||||
ON DUPLICATE KEY UPDATE sent_at=sent_at`, dateOnly(reportDate), accountIDHash)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) InsertReconciliation(ctx context.Context, ts time.Time, diffJSON string, hasDiff bool) error {
|
||||
if ts.IsZero() {
|
||||
ts = time.Now().UTC()
|
||||
}
|
||||
if diffJSON == "" {
|
||||
diffJSON = "[]"
|
||||
}
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO reconciliations (ts, has_diff, diff_json)
|
||||
VALUES (?, ?, ?)`, ts, hasDiff, diffJSON)
|
||||
return err
|
||||
}
|
||||
|
||||
func dateOnly(t time.Time) time.Time {
|
||||
y, m, d := t.UTC().Date()
|
||||
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
func nullableString(s string) any {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
type instrumentRow struct {
|
||||
InstrumentUID string `db:"instrument_uid"`
|
||||
Figi sql.NullString `db:"figi"`
|
||||
Ticker string `db:"ticker"`
|
||||
ClassCode string `db:"class_code"`
|
||||
Name string `db:"name"`
|
||||
Lot int64 `db:"lot"`
|
||||
MinPriceIncrement decimal.Decimal `db:"min_price_increment"`
|
||||
Currency string `db:"currency"`
|
||||
Enabled bool `db:"enabled"`
|
||||
FundType string `db:"fund_type"`
|
||||
ExpectedCommissionBpsPerSide decimal.Decimal `db:"expected_commission_bps_per_side"`
|
||||
FreeOrderLimitPerDay int `db:"free_order_limit_per_day"`
|
||||
Quarantine bool `db:"quarantine"`
|
||||
QuarantineReason sql.NullString `db:"quarantine_reason"`
|
||||
ExcludeReason sql.NullString `db:"exclude_reason"`
|
||||
UpdatedAt time.Time `db:"updated_at"`
|
||||
}
|
||||
|
||||
func instrumentRowFromDomain(instrument domain.Instrument) instrumentRow {
|
||||
return instrumentRow{
|
||||
InstrumentUID: instrument.InstrumentUID,
|
||||
Figi: sql.NullString{String: instrument.Figi, Valid: instrument.Figi != ""},
|
||||
Ticker: instrument.Ticker,
|
||||
ClassCode: instrument.ClassCode,
|
||||
Name: instrument.Name,
|
||||
Lot: instrument.Lot,
|
||||
MinPriceIncrement: instrument.MinPriceIncrement,
|
||||
Currency: instrument.Currency,
|
||||
Enabled: instrument.Enabled,
|
||||
FundType: instrument.FundType,
|
||||
ExpectedCommissionBpsPerSide: instrument.ExpectedCommissionBpsPerSide,
|
||||
FreeOrderLimitPerDay: instrument.FreeOrderLimitPerDay,
|
||||
Quarantine: instrument.Quarantine,
|
||||
QuarantineReason: sql.NullString{String: instrument.QuarantineReason, Valid: instrument.QuarantineReason != ""},
|
||||
ExcludeReason: sql.NullString{String: instrument.ExcludeReason, Valid: instrument.ExcludeReason != ""},
|
||||
UpdatedAt: instrument.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func replaceInstrumentRowFromDomain(oldInstrumentUID string, instrument domain.Instrument) map[string]any {
|
||||
row := instrumentRowFromDomain(instrument)
|
||||
return map[string]any{
|
||||
"instrument_uid": row.InstrumentUID,
|
||||
"figi": row.Figi,
|
||||
"ticker": row.Ticker,
|
||||
"class_code": row.ClassCode,
|
||||
"name": row.Name,
|
||||
"lot": row.Lot,
|
||||
"min_price_increment": row.MinPriceIncrement,
|
||||
"currency": row.Currency,
|
||||
"enabled": row.Enabled,
|
||||
"fund_type": row.FundType,
|
||||
"expected_commission_bps_per_side": row.ExpectedCommissionBpsPerSide,
|
||||
"free_order_limit_per_day": row.FreeOrderLimitPerDay,
|
||||
"quarantine": row.Quarantine,
|
||||
"quarantine_reason": row.QuarantineReason,
|
||||
"exclude_reason": row.ExcludeReason,
|
||||
"updated_at": row.UpdatedAt,
|
||||
"old_instrument_uid": oldInstrumentUID,
|
||||
}
|
||||
}
|
||||
|
||||
func (r instrumentRow) domain() domain.Instrument {
|
||||
return domain.Instrument{
|
||||
InstrumentUID: r.InstrumentUID,
|
||||
Figi: r.Figi.String,
|
||||
Ticker: r.Ticker,
|
||||
ClassCode: r.ClassCode,
|
||||
Name: r.Name,
|
||||
Lot: r.Lot,
|
||||
MinPriceIncrement: r.MinPriceIncrement,
|
||||
Currency: r.Currency,
|
||||
Enabled: r.Enabled,
|
||||
FundType: r.FundType,
|
||||
ExpectedCommissionBpsPerSide: r.ExpectedCommissionBpsPerSide,
|
||||
FreeOrderLimitPerDay: r.FreeOrderLimitPerDay,
|
||||
Quarantine: r.Quarantine,
|
||||
QuarantineReason: r.QuarantineReason.String,
|
||||
ExcludeReason: r.ExcludeReason.String,
|
||||
UpdatedAt: r.UpdatedAt,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
//go:build integration
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mariadb"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
func TestRepositoryMariaDBMigrationsAndRoundTrip(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
container, err := mariadb.Run(ctx,
|
||||
"mariadb:11.4",
|
||||
mariadb.WithDatabase("overnight_bot"),
|
||||
mariadb.WithUsername("bot"),
|
||||
mariadb.WithPassword("bot"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := testcontainers.TerminateContainer(container); err != nil {
|
||||
t.Logf("terminate mariadb: %v", err)
|
||||
}
|
||||
})
|
||||
dsn, err := container.ConnectionString(ctx, "parseTime=true", "loc=UTC", "multiStatements=true")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
db, err := sqlx.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = db.Close()
|
||||
})
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := ApplyMigrations(ctx, db.DB); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
repo := NewRepository(db)
|
||||
instrument := domain.Instrument{
|
||||
InstrumentUID: "uid-trur",
|
||||
Ticker: "TRUR",
|
||||
ClassCode: "TQTF",
|
||||
Name: "TRUR",
|
||||
Lot: 1,
|
||||
MinPriceIncrement: decimal.NewFromFloat(0.0001),
|
||||
Currency: "RUB",
|
||||
Enabled: true,
|
||||
}
|
||||
if err := repo.ReplaceInstrument(ctx, "PENDING:TRUR", instrument); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tradeDate := time.Date(2026, 6, 7, 0, 0, 0, 0, time.UTC)
|
||||
position := domain.Position{
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid-trur",
|
||||
OpenTradeDate: tradeDate,
|
||||
Lots: 10,
|
||||
AvgBuyPrice: decimal.NewFromInt(100),
|
||||
Status: domain.PositionHoldingOvernight,
|
||||
}
|
||||
if err := repo.UpsertPosition(ctx, position); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
position.Lots = 8
|
||||
position.ExitFilledLots = 2
|
||||
if err := repo.UpsertPosition(ctx, position); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var count int
|
||||
if err := db.GetContext(ctx, &count, `
|
||||
SELECT COUNT(*) FROM positions WHERE account_id_hash='hash' AND instrument_uid='uid-trur' AND open_trade_date=?`, tradeDate); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("positions count=%d, want 1", count)
|
||||
}
|
||||
if err := repo.MarkDailyReportSent(ctx, tradeDate, "hash"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sent, err := repo.WasDailyReportSent(ctx, tradeDate, "hash")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !sent {
|
||||
t.Fatalf("daily report marker was not persisted")
|
||||
}
|
||||
if err := repo.UpsertOrder(ctx, domain.Order{
|
||||
ClientOrderID: "bad",
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "missing",
|
||||
TradeDate: tradeDate,
|
||||
Side: domain.SideBuy,
|
||||
OrderType: domain.OrderTypeLimit,
|
||||
LimitPrice: decimal.NewFromInt(100),
|
||||
QuantityLots: 1,
|
||||
Status: domain.OrderStatusSent,
|
||||
RawStateJSON: "{}",
|
||||
}); err == nil {
|
||||
t.Fatalf("expected FK failure for missing instrument")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,338 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
type candleRow struct {
|
||||
InstrumentUID string `db:"instrument_uid"`
|
||||
TradeDate time.Time `db:"trade_date"`
|
||||
Open decimal.Decimal `db:"open"`
|
||||
High decimal.Decimal `db:"high"`
|
||||
Low decimal.Decimal `db:"low"`
|
||||
Close decimal.Decimal `db:"close"`
|
||||
VolumeLots decimal.Decimal `db:"volume_lots"`
|
||||
Source string `db:"source"`
|
||||
LoadedAt time.Time `db:"loaded_at"`
|
||||
}
|
||||
|
||||
func candleRowFromDomain(candle domain.Candle) candleRow {
|
||||
return candleRow{
|
||||
InstrumentUID: candle.InstrumentUID,
|
||||
TradeDate: dateOnly(candle.TradeDate),
|
||||
Open: candle.Open,
|
||||
High: candle.High,
|
||||
Low: candle.Low,
|
||||
Close: candle.Close,
|
||||
VolumeLots: candle.VolumeLots,
|
||||
Source: candle.Source,
|
||||
LoadedAt: candle.LoadedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (r candleRow) domain() domain.Candle {
|
||||
return domain.Candle{
|
||||
InstrumentUID: r.InstrumentUID,
|
||||
TradeDate: r.TradeDate,
|
||||
Open: r.Open,
|
||||
High: r.High,
|
||||
Low: r.Low,
|
||||
Close: r.Close,
|
||||
VolumeLots: r.VolumeLots,
|
||||
Source: r.Source,
|
||||
LoadedAt: r.LoadedAt,
|
||||
}
|
||||
}
|
||||
|
||||
type featureRow struct {
|
||||
InstrumentUID string `db:"instrument_uid"`
|
||||
TradeDate time.Time `db:"trade_date"`
|
||||
ROn decimal.Decimal `db:"r_on"`
|
||||
RDay decimal.Decimal `db:"r_day"`
|
||||
MuOn60 decimal.Decimal `db:"mu_on_60"`
|
||||
MuOn252 decimal.Decimal `db:"mu_on_252"`
|
||||
SigmaOn60 decimal.Decimal `db:"sigma_on_60"`
|
||||
TStatOn60 decimal.Decimal `db:"tstat_on_60"`
|
||||
WinOn60 decimal.Decimal `db:"win_on_60"`
|
||||
EWMAOn decimal.Decimal `db:"ewma_on"`
|
||||
SpreadBps decimal.Decimal `db:"spread_bps"`
|
||||
HalfSpreadBps decimal.Decimal `db:"half_spread_bps"`
|
||||
TickBps decimal.Decimal `db:"tick_bps"`
|
||||
ADV20 decimal.Decimal `db:"adv_20"`
|
||||
ExpectedCostBps decimal.Decimal `db:"expected_cost_bps"`
|
||||
NetEdgeBps decimal.Decimal `db:"net_edge_bps"`
|
||||
EntryIntervalVolume decimal.Decimal `db:"entry_interval_volume"`
|
||||
ExitIntervalVolume decimal.Decimal `db:"exit_interval_volume"`
|
||||
CalculatedAt time.Time `db:"calculated_at"`
|
||||
}
|
||||
|
||||
func featureRowFromDomain(feature domain.FeatureSet) featureRow {
|
||||
return featureRow{
|
||||
InstrumentUID: feature.InstrumentUID,
|
||||
TradeDate: dateOnly(feature.TradeDate),
|
||||
ROn: feature.ROn,
|
||||
RDay: feature.RDay,
|
||||
MuOn60: feature.MuOn60,
|
||||
MuOn252: feature.MuOn252,
|
||||
SigmaOn60: feature.SigmaOn60,
|
||||
TStatOn60: feature.TStatOn60,
|
||||
WinOn60: feature.WinOn60,
|
||||
EWMAOn: feature.EWMAOn,
|
||||
SpreadBps: feature.SpreadBps,
|
||||
HalfSpreadBps: feature.HalfSpreadBps,
|
||||
TickBps: feature.TickBps,
|
||||
ADV20: feature.ADV20,
|
||||
ExpectedCostBps: feature.ExpectedCostBps,
|
||||
NetEdgeBps: feature.NetEdgeBps,
|
||||
EntryIntervalVolume: feature.EntryIntervalVolume,
|
||||
ExitIntervalVolume: feature.ExitIntervalVolume,
|
||||
CalculatedAt: feature.CalculatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (r featureRow) domain() domain.FeatureSet {
|
||||
return domain.FeatureSet{
|
||||
InstrumentUID: r.InstrumentUID,
|
||||
TradeDate: r.TradeDate,
|
||||
ROn: r.ROn,
|
||||
RDay: r.RDay,
|
||||
MuOn60: r.MuOn60,
|
||||
MuOn252: r.MuOn252,
|
||||
SigmaOn60: r.SigmaOn60,
|
||||
TStatOn60: r.TStatOn60,
|
||||
WinOn60: r.WinOn60,
|
||||
EWMAOn: r.EWMAOn,
|
||||
SpreadBps: r.SpreadBps,
|
||||
HalfSpreadBps: r.HalfSpreadBps,
|
||||
TickBps: r.TickBps,
|
||||
ADV20: r.ADV20,
|
||||
ExpectedCostBps: r.ExpectedCostBps,
|
||||
NetEdgeBps: r.NetEdgeBps,
|
||||
EntryIntervalVolume: r.EntryIntervalVolume,
|
||||
ExitIntervalVolume: r.ExitIntervalVolume,
|
||||
CalculatedAt: r.CalculatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
type signalRow struct {
|
||||
ID int64 `db:"id"`
|
||||
TradeDate time.Time `db:"trade_date"`
|
||||
InstrumentUID string `db:"instrument_uid"`
|
||||
Decision string `db:"decision"`
|
||||
Score decimal.Decimal `db:"score"`
|
||||
NetEdgeBps decimal.Decimal `db:"net_edge_bps"`
|
||||
TargetNotional decimal.Decimal `db:"target_notional"`
|
||||
TargetLots int64 `db:"target_lots"`
|
||||
RejectReason sql.NullString `db:"reject_reason"`
|
||||
ContextJSON sql.NullString `db:"context_json"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
}
|
||||
|
||||
func signalRowFromDomain(signal domain.Signal) signalRow {
|
||||
return signalRow{
|
||||
ID: signal.ID,
|
||||
TradeDate: dateOnly(signal.TradeDate),
|
||||
InstrumentUID: signal.InstrumentUID,
|
||||
Decision: string(signal.Decision),
|
||||
Score: signal.Score,
|
||||
NetEdgeBps: signal.NetEdgeBps,
|
||||
TargetNotional: signal.TargetNotional,
|
||||
TargetLots: signal.TargetLots,
|
||||
RejectReason: sql.NullString{String: signal.RejectReason, Valid: signal.RejectReason != ""},
|
||||
ContextJSON: sql.NullString{String: signal.ContextJSON, Valid: signal.ContextJSON != ""},
|
||||
CreatedAt: signal.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (r signalRow) domain() domain.Signal {
|
||||
return domain.Signal{
|
||||
ID: r.ID,
|
||||
TradeDate: r.TradeDate,
|
||||
InstrumentUID: r.InstrumentUID,
|
||||
Decision: domain.SignalDecision(r.Decision),
|
||||
Score: r.Score,
|
||||
NetEdgeBps: r.NetEdgeBps,
|
||||
TargetNotional: r.TargetNotional,
|
||||
TargetLots: r.TargetLots,
|
||||
RejectReason: r.RejectReason.String,
|
||||
ContextJSON: r.ContextJSON.String,
|
||||
CreatedAt: r.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
type orderRow struct {
|
||||
ClientOrderID string `db:"client_order_id"`
|
||||
BrokerOrderID sql.NullString `db:"broker_order_id"`
|
||||
AccountIDHash string `db:"account_id_hash"`
|
||||
InstrumentUID string `db:"instrument_uid"`
|
||||
TradeDate time.Time `db:"trade_date"`
|
||||
Side string `db:"side"`
|
||||
OrderType string `db:"order_type"`
|
||||
LimitPrice decimal.Decimal `db:"limit_price"`
|
||||
QuantityLots int64 `db:"quantity_lots"`
|
||||
FilledLots int64 `db:"filled_lots"`
|
||||
AvgFillPrice decimal.Decimal `db:"avg_fill_price"`
|
||||
Status string `db:"status"`
|
||||
Commission decimal.Decimal `db:"commission"`
|
||||
AttemptNo int `db:"attempt_no"`
|
||||
RawStateJSON sql.NullString `db:"raw_state_json"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at"`
|
||||
}
|
||||
|
||||
func orderRowFromDomain(order domain.Order) orderRow {
|
||||
return orderRow{
|
||||
ClientOrderID: order.ClientOrderID,
|
||||
BrokerOrderID: sql.NullString{
|
||||
String: order.BrokerOrderID,
|
||||
Valid: order.BrokerOrderID != "",
|
||||
},
|
||||
AccountIDHash: order.AccountIDHash,
|
||||
InstrumentUID: order.InstrumentUID,
|
||||
TradeDate: dateOnly(order.TradeDate),
|
||||
Side: string(order.Side),
|
||||
OrderType: string(order.OrderType),
|
||||
LimitPrice: order.LimitPrice,
|
||||
QuantityLots: order.QuantityLots,
|
||||
FilledLots: order.FilledLots,
|
||||
AvgFillPrice: order.AvgFillPrice,
|
||||
Status: string(order.Status),
|
||||
Commission: order.Commission,
|
||||
AttemptNo: order.AttemptNo,
|
||||
RawStateJSON: sql.NullString{
|
||||
String: order.RawStateJSON,
|
||||
Valid: order.RawStateJSON != "",
|
||||
},
|
||||
CreatedAt: order.CreatedAt,
|
||||
UpdatedAt: order.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (r orderRow) domain() domain.Order {
|
||||
return domain.Order{
|
||||
ClientOrderID: r.ClientOrderID,
|
||||
BrokerOrderID: r.BrokerOrderID.String,
|
||||
AccountIDHash: r.AccountIDHash,
|
||||
InstrumentUID: r.InstrumentUID,
|
||||
TradeDate: r.TradeDate,
|
||||
Side: domain.Side(r.Side),
|
||||
OrderType: domain.OrderType(r.OrderType),
|
||||
LimitPrice: r.LimitPrice,
|
||||
QuantityLots: r.QuantityLots,
|
||||
FilledLots: r.FilledLots,
|
||||
AvgFillPrice: r.AvgFillPrice,
|
||||
Status: domain.OrderStatus(r.Status),
|
||||
Commission: r.Commission,
|
||||
AttemptNo: r.AttemptNo,
|
||||
RawStateJSON: r.RawStateJSON.String,
|
||||
CreatedAt: r.CreatedAt,
|
||||
UpdatedAt: r.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
type positionRow struct {
|
||||
ID int64 `db:"id"`
|
||||
AccountIDHash string `db:"account_id_hash"`
|
||||
InstrumentUID string `db:"instrument_uid"`
|
||||
OpenTradeDate time.Time `db:"open_trade_date"`
|
||||
Lots int64 `db:"lots"`
|
||||
Lot int64 `db:"lot_size"`
|
||||
ExitFilledLots int64 `db:"exit_filled_lots"`
|
||||
AvgBuyPrice decimal.Decimal `db:"avg_buy_price"`
|
||||
AvgSellPrice decimal.Decimal `db:"avg_sell_price"`
|
||||
Status string `db:"status"`
|
||||
GrossPnL decimal.Decimal `db:"gross_pnl"`
|
||||
NetPnL decimal.Decimal `db:"net_pnl"`
|
||||
CommissionTotal decimal.Decimal `db:"commission_total"`
|
||||
RealizedEdgeBps decimal.Decimal `db:"realized_edge_bps"`
|
||||
OpenedAt sql.NullTime `db:"opened_at"`
|
||||
ClosedAt sql.NullTime `db:"closed_at"`
|
||||
UpdatedAt time.Time `db:"updated_at"`
|
||||
}
|
||||
|
||||
func positionRowFromDomain(position domain.Position) positionRow {
|
||||
lot := position.Lot
|
||||
if lot <= 0 {
|
||||
lot = 1
|
||||
}
|
||||
return positionRow{
|
||||
ID: position.ID,
|
||||
AccountIDHash: position.AccountIDHash,
|
||||
InstrumentUID: position.InstrumentUID,
|
||||
OpenTradeDate: dateOnly(position.OpenTradeDate),
|
||||
Lots: position.Lots,
|
||||
Lot: lot,
|
||||
ExitFilledLots: position.ExitFilledLots,
|
||||
AvgBuyPrice: position.AvgBuyPrice,
|
||||
AvgSellPrice: position.AvgSellPrice,
|
||||
Status: string(position.Status),
|
||||
GrossPnL: position.GrossPnL,
|
||||
NetPnL: position.NetPnL,
|
||||
CommissionTotal: position.CommissionTotal,
|
||||
RealizedEdgeBps: position.RealizedEdgeBps,
|
||||
OpenedAt: nullableTime(position.OpenedAt),
|
||||
ClosedAt: nullableTime(position.ClosedAt),
|
||||
UpdatedAt: position.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (r positionRow) domain() domain.Position {
|
||||
return domain.Position{
|
||||
ID: r.ID,
|
||||
AccountIDHash: r.AccountIDHash,
|
||||
InstrumentUID: r.InstrumentUID,
|
||||
OpenTradeDate: r.OpenTradeDate,
|
||||
Lots: r.Lots,
|
||||
Lot: r.Lot,
|
||||
ExitFilledLots: r.ExitFilledLots,
|
||||
AvgBuyPrice: r.AvgBuyPrice,
|
||||
AvgSellPrice: r.AvgSellPrice,
|
||||
Status: domain.PositionStatus(r.Status),
|
||||
GrossPnL: r.GrossPnL,
|
||||
NetPnL: r.NetPnL,
|
||||
CommissionTotal: r.CommissionTotal,
|
||||
RealizedEdgeBps: r.RealizedEdgeBps,
|
||||
OpenedAt: timePtr(r.OpenedAt),
|
||||
ClosedAt: timePtr(r.ClosedAt),
|
||||
UpdatedAt: r.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
type riskEventRow struct {
|
||||
TS time.Time `db:"ts"`
|
||||
Severity string `db:"severity"`
|
||||
EventType string `db:"event_type"`
|
||||
InstrumentUID sql.NullString `db:"instrument_uid"`
|
||||
Message string `db:"message"`
|
||||
ContextJSON string `db:"raw_context_json"`
|
||||
}
|
||||
|
||||
func riskEventRowFromDomain(event domain.RiskEvent) riskEventRow {
|
||||
return riskEventRow{
|
||||
TS: event.TS,
|
||||
Severity: string(event.Severity),
|
||||
EventType: event.EventType,
|
||||
InstrumentUID: sql.NullString{String: event.InstrumentUID, Valid: event.InstrumentUID != ""},
|
||||
Message: event.Message,
|
||||
ContextJSON: event.ContextJSON,
|
||||
}
|
||||
}
|
||||
|
||||
func nullableTime(t *time.Time) sql.NullTime {
|
||||
if t == nil {
|
||||
return sql.NullTime{}
|
||||
}
|
||||
return sql.NullTime{Time: *t, Valid: true}
|
||||
}
|
||||
|
||||
func timePtr(t sql.NullTime) *time.Time {
|
||||
if !t.Valid {
|
||||
return nil
|
||||
}
|
||||
return &t.Time
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
RunInTx(ctx context.Context, fn func(ctx context.Context, repo Repository) error) error
|
||||
|
||||
UpsertInstrument(ctx context.Context, instrument domain.Instrument) error
|
||||
ReplaceInstrument(ctx context.Context, oldInstrumentUID string, instrument domain.Instrument) error
|
||||
ListInstruments(ctx context.Context, includeDisabled bool) ([]domain.Instrument, error)
|
||||
QuarantineInstrument(ctx context.Context, instrumentUID, reason string) error
|
||||
|
||||
UpsertDailyCandles(ctx context.Context, candles []domain.Candle) error
|
||||
ListDailyCandles(ctx context.Context, instrumentUID string, from, to time.Time) ([]domain.Candle, error)
|
||||
UpsertMinuteCandles(ctx context.Context, candles []domain.Candle) error
|
||||
ListMinuteCandles(ctx context.Context, instrumentUID string, from, to time.Time) ([]domain.Candle, error)
|
||||
|
||||
UpsertFeature(ctx context.Context, feature domain.FeatureSet) error
|
||||
GetFeature(ctx context.Context, instrumentUID string, tradeDate time.Time) (domain.FeatureSet, error)
|
||||
|
||||
UpsertSignal(ctx context.Context, signal domain.Signal) error
|
||||
ListSignals(ctx context.Context, tradeDate time.Time) ([]domain.Signal, error)
|
||||
|
||||
UpsertOrder(ctx context.Context, order domain.Order) error
|
||||
UpdateOrderStatus(ctx context.Context, clientOrderID string, status domain.OrderStatus, filledLots int64, rawJSON string) error
|
||||
ListActiveOrders(ctx context.Context, accountIDHash string) ([]domain.Order, error)
|
||||
ListOrders(ctx context.Context, accountIDHash string, from, to time.Time) ([]domain.Order, error)
|
||||
|
||||
UpsertPosition(ctx context.Context, position domain.Position) error
|
||||
ListOpenPositions(ctx context.Context, accountIDHash string) ([]domain.Position, error)
|
||||
ListPositions(ctx context.Context, accountIDHash string, from, to time.Time) ([]domain.Position, error)
|
||||
|
||||
InsertRiskEvent(ctx context.Context, event domain.RiskEvent) error
|
||||
GetFreeOrdersSent(ctx context.Context, tradeDate time.Time, instrumentUID string) (int, error)
|
||||
IncrementFreeOrders(ctx context.Context, tradeDate time.Time, instrumentUID string, delta int) error
|
||||
|
||||
GetSystemState(ctx context.Context) (domain.SystemState, bool, string, error)
|
||||
SaveSystemState(ctx context.Context, state domain.SystemState, mode domain.Mode, halted bool, reason string, contextJSON string) error
|
||||
Unhalt(ctx context.Context, reason string) error
|
||||
WasDailyReportSent(ctx context.Context, reportDate time.Time, accountIDHash string) (bool, error)
|
||||
MarkDailyReportSent(ctx context.Context, reportDate time.Time, accountIDHash string) error
|
||||
|
||||
InsertReconciliation(ctx context.Context, ts time.Time, diffJSON string, hasDiff bool) error
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package risk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
func TestFreeOrderBudgetSubmittedPolicy(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store := NewMemoryFreeOrderStore()
|
||||
budget := NewFreeOrderBudget(store)
|
||||
instr := domain.Instrument{InstrumentUID: "uid", FreeOrderLimitPerDay: 2}
|
||||
date := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
|
||||
if _, err := budget.Check(ctx, date, instr, 2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := budget.Submitted(ctx, date, instr.InstrumentUID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := budget.Check(ctx, date, instr, 2); !errors.Is(err, ErrFreeOrderBudget) {
|
||||
t.Fatalf("expected ErrFreeOrderBudget, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package risk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
type EventSink interface {
|
||||
InsertRiskEvent(ctx context.Context, event domain.RiskEvent) error
|
||||
SaveSystemState(ctx context.Context, state domain.SystemState, mode domain.Mode, halted bool, reason string, contextJSON string) error
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
sink EventSink
|
||||
cfg ManagerConfig
|
||||
}
|
||||
|
||||
type ManagerConfig struct {
|
||||
MaxDailyLossPct decimal.Decimal
|
||||
MaxWeeklyLossPct decimal.Decimal
|
||||
MaxMonthlyDrawdownPct decimal.Decimal
|
||||
MaxAvgSlippageBps10Trades decimal.Decimal
|
||||
MaxOpenPositions int
|
||||
MinTimeToClose time.Duration
|
||||
MaxQuoteAge time.Duration
|
||||
}
|
||||
|
||||
type PreTradeInput struct {
|
||||
Portfolio domain.Portfolio
|
||||
OpenPositions int
|
||||
DailyPnL decimal.Decimal
|
||||
WeeklyPnL decimal.Decimal
|
||||
MonthlyDrawdownPct decimal.Decimal
|
||||
AvgSlippageBps10 decimal.Decimal
|
||||
TradingStatus domain.TradingStatus
|
||||
QuoteReceivedAt time.Time
|
||||
Now time.Time
|
||||
MarketClose time.Time
|
||||
DatabaseUnavailable bool
|
||||
UnknownBrokerOrder bool
|
||||
UnknownBrokerHolding bool
|
||||
}
|
||||
|
||||
type PreTradeResult struct {
|
||||
Allowed bool
|
||||
Reason string
|
||||
}
|
||||
|
||||
func NewManager(sink EventSink, cfg ManagerConfig) Manager {
|
||||
return Manager{sink: sink, cfg: cfg}
|
||||
}
|
||||
|
||||
func (m Manager) Halt(ctx context.Context, mode domain.Mode, eventType, reason string, instrumentUID string) error {
|
||||
if m.sink == nil {
|
||||
return nil
|
||||
}
|
||||
event := domain.RiskEvent{
|
||||
TS: time.Now().UTC(),
|
||||
Severity: domain.SeverityCritical,
|
||||
EventType: eventType,
|
||||
InstrumentUID: instrumentUID,
|
||||
Message: reason,
|
||||
}
|
||||
if err := m.sink.InsertRiskEvent(ctx, event); err != nil {
|
||||
return fmt.Errorf("insert halt risk event: %w", err)
|
||||
}
|
||||
if err := m.sink.SaveSystemState(ctx, domain.StateHalted, mode, true, reason, "{}"); err != nil {
|
||||
return fmt.Errorf("persist halt state: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) PreTradeCheck(input PreTradeInput) PreTradeResult {
|
||||
now := input.Now
|
||||
if now.IsZero() {
|
||||
now = time.Now().UTC()
|
||||
}
|
||||
switch {
|
||||
case input.DatabaseUnavailable:
|
||||
return reject("database_unavailable")
|
||||
case input.UnknownBrokerOrder:
|
||||
return reject("unknown_broker_order")
|
||||
case input.UnknownBrokerHolding:
|
||||
return reject("unknown_broker_position")
|
||||
case input.TradingStatus == domain.TradingStatusUnknown:
|
||||
return reject("trading_status_unknown_before_order")
|
||||
case input.TradingStatus != domain.TradingStatusNormal:
|
||||
return reject("trading_status_not_normal")
|
||||
case m.cfg.MaxOpenPositions > 0 && input.OpenPositions >= m.cfg.MaxOpenPositions:
|
||||
return reject("max_open_positions")
|
||||
case DailyLossBreached(input.DailyPnL, input.Portfolio.Equity, m.cfg.MaxDailyLossPct):
|
||||
return reject("max_daily_loss")
|
||||
case DailyLossBreached(input.WeeklyPnL, input.Portfolio.Equity, m.cfg.MaxWeeklyLossPct):
|
||||
return reject("max_weekly_loss")
|
||||
case m.cfg.MaxMonthlyDrawdownPct.IsPositive() && input.MonthlyDrawdownPct.GreaterThanOrEqual(m.cfg.MaxMonthlyDrawdownPct):
|
||||
return reject("max_monthly_drawdown")
|
||||
case m.cfg.MaxAvgSlippageBps10Trades.IsPositive() && input.AvgSlippageBps10.GreaterThan(m.cfg.MaxAvgSlippageBps10Trades):
|
||||
return reject("max_avg_slippage_bps_10_trades")
|
||||
case m.cfg.MaxQuoteAge > 0 && !input.QuoteReceivedAt.IsZero() && now.Sub(input.QuoteReceivedAt) > m.cfg.MaxQuoteAge:
|
||||
return reject("quote_age_too_high")
|
||||
case m.cfg.MinTimeToClose > 0 && !input.MarketClose.IsZero() && input.MarketClose.Sub(now) < m.cfg.MinTimeToClose:
|
||||
return reject("min_time_to_close_sec")
|
||||
default:
|
||||
return PreTradeResult{Allowed: true}
|
||||
}
|
||||
}
|
||||
|
||||
func DailyLossBreached(pnl, equity, maxLossPct decimal.Decimal) bool {
|
||||
if !equity.IsPositive() || !maxLossPct.IsPositive() {
|
||||
return false
|
||||
}
|
||||
limit := equity.Mul(maxLossPct).Neg()
|
||||
return pnl.LessThanOrEqual(limit)
|
||||
}
|
||||
|
||||
func CommissionBreached(actualCommission decimal.Decimal, requireZero bool) bool {
|
||||
return requireZero && actualCommission.IsPositive()
|
||||
}
|
||||
|
||||
func reject(reason string) PreTradeResult {
|
||||
return PreTradeResult{Allowed: false, Reason: reason}
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
package risk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/money"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoSizingCapacity = errors.New("no sizing capacity")
|
||||
ErrFreeOrderBudget = errors.New("free order budget is insufficient")
|
||||
)
|
||||
|
||||
type SizingConfig struct {
|
||||
MaxPositionPct decimal.Decimal
|
||||
MaxTotalExposurePct decimal.Decimal
|
||||
MaxParticipationRate decimal.Decimal
|
||||
CashUsageBuffer decimal.Decimal
|
||||
RiskBudgetPerInstrumentPct decimal.Decimal
|
||||
MinOrderNotionalRUB decimal.Decimal
|
||||
}
|
||||
|
||||
type SizingInput struct {
|
||||
Portfolio domain.Portfolio
|
||||
SelectedInstruments int
|
||||
LimitPrice decimal.Decimal
|
||||
Lot int64
|
||||
EntryIntervalVolume decimal.Decimal
|
||||
ExitIntervalVolume decimal.Decimal
|
||||
Q05OvernightAbs decimal.Decimal
|
||||
}
|
||||
|
||||
type SizingResult struct {
|
||||
TargetNotional decimal.Decimal
|
||||
Lots int64
|
||||
Reason string
|
||||
Limits map[string]decimal.Decimal
|
||||
}
|
||||
|
||||
type Sizer struct {
|
||||
cfg SizingConfig
|
||||
sizeFactor decimal.Decimal
|
||||
}
|
||||
|
||||
func NewSizer(cfg SizingConfig) Sizer {
|
||||
return Sizer{cfg: cfg, sizeFactor: decimal.NewFromInt(1)}
|
||||
}
|
||||
|
||||
func (s Sizer) WithSizeFactor(factor decimal.Decimal) Sizer {
|
||||
if !factor.IsPositive() {
|
||||
factor = decimal.NewFromInt(1)
|
||||
}
|
||||
s.sizeFactor = factor
|
||||
return s
|
||||
}
|
||||
|
||||
func (s Sizer) Size(input SizingInput) SizingResult {
|
||||
limits := make(map[string]decimal.Decimal, 6)
|
||||
if input.SelectedInstruments <= 0 {
|
||||
input.SelectedInstruments = 1
|
||||
}
|
||||
capLimit := input.Portfolio.Equity.Mul(s.cfg.MaxPositionPct)
|
||||
exposureLimit := input.Portfolio.Equity.Mul(s.cfg.MaxTotalExposurePct).
|
||||
Div(decimal.NewFromInt(int64(input.SelectedInstruments)))
|
||||
liquidityLimit := money.Min(input.EntryIntervalVolume, input.ExitIntervalVolume).
|
||||
Mul(s.cfg.MaxParticipationRate)
|
||||
cashLimit := input.Portfolio.Cash.Mul(s.cfg.CashUsageBuffer)
|
||||
riskLimit := capLimit
|
||||
if input.Q05OvernightAbs.IsPositive() {
|
||||
riskBudget := input.Portfolio.Equity.Mul(s.cfg.RiskBudgetPerInstrumentPct)
|
||||
riskLimit = riskBudget.Div(input.Q05OvernightAbs)
|
||||
}
|
||||
limits["cap"] = capLimit
|
||||
limits["exposure"] = exposureLimit
|
||||
limits["liquidity"] = liquidityLimit
|
||||
limits["risk"] = riskLimit
|
||||
limits["cash"] = cashLimit
|
||||
|
||||
sizeFactor := s.effectiveSizeFactor()
|
||||
limits["size_factor"] = sizeFactor
|
||||
target := money.Min(capLimit, exposureLimit, liquidityLimit, riskLimit, cashLimit).Mul(sizeFactor)
|
||||
if !target.IsPositive() || !input.LimitPrice.IsPositive() || input.Lot <= 0 {
|
||||
return SizingResult{Reason: "non_positive_limit", Limits: limits}
|
||||
}
|
||||
lotNotional := input.LimitPrice.Mul(decimal.NewFromInt(input.Lot))
|
||||
lots := target.Div(lotNotional).Floor().IntPart()
|
||||
notional := lotNotional.Mul(decimal.NewFromInt(lots))
|
||||
if lots < 1 {
|
||||
return SizingResult{TargetNotional: notional, Lots: lots, Reason: "lots_below_one", Limits: limits}
|
||||
}
|
||||
if notional.LessThan(s.cfg.MinOrderNotionalRUB) {
|
||||
return SizingResult{TargetNotional: notional, Lots: 0, Reason: "min_order_notional", Limits: limits}
|
||||
}
|
||||
return SizingResult{TargetNotional: notional, Lots: lots, Limits: limits}
|
||||
}
|
||||
|
||||
func (s Sizer) effectiveSizeFactor() decimal.Decimal {
|
||||
if !s.sizeFactor.IsPositive() {
|
||||
return decimal.NewFromInt(1)
|
||||
}
|
||||
return s.sizeFactor
|
||||
}
|
||||
|
||||
type FreeOrderStore interface {
|
||||
GetFreeOrdersSent(ctx context.Context, tradeDate time.Time, instrumentUID string) (int, error)
|
||||
IncrementFreeOrders(ctx context.Context, tradeDate time.Time, instrumentUID string, delta int) error
|
||||
}
|
||||
|
||||
type FreeOrderBudget struct {
|
||||
store FreeOrderStore
|
||||
}
|
||||
|
||||
func NewFreeOrderBudget(store FreeOrderStore) FreeOrderBudget {
|
||||
return FreeOrderBudget{store: store}
|
||||
}
|
||||
|
||||
func (b FreeOrderBudget) Check(ctx context.Context, tradeDate time.Time, instr domain.Instrument, ordersNeeded int) (int, error) {
|
||||
if instr.FreeOrderLimitPerDay <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
sent, err := b.store.GetFreeOrdersSent(ctx, tradeDate, instr.InstrumentUID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
remaining := instr.FreeOrderLimitPerDay - sent
|
||||
if remaining < ordersNeeded {
|
||||
return remaining, ErrFreeOrderBudget
|
||||
}
|
||||
return remaining, nil
|
||||
}
|
||||
|
||||
func (b FreeOrderBudget) Submitted(ctx context.Context, tradeDate time.Time, instrumentUID string) error {
|
||||
return b.store.IncrementFreeOrders(ctx, tradeDate, instrumentUID, 1)
|
||||
}
|
||||
|
||||
type MemoryFreeOrderStore struct {
|
||||
mu sync.Mutex
|
||||
counts map[string]int
|
||||
}
|
||||
|
||||
func NewMemoryFreeOrderStore() *MemoryFreeOrderStore {
|
||||
return &MemoryFreeOrderStore{counts: make(map[string]int)}
|
||||
}
|
||||
|
||||
func (s *MemoryFreeOrderStore) GetFreeOrdersSent(_ context.Context, tradeDate time.Time, instrumentUID string) (int, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.counts[freeOrderKey(tradeDate, instrumentUID)], nil
|
||||
}
|
||||
|
||||
func (s *MemoryFreeOrderStore) IncrementFreeOrders(_ context.Context, tradeDate time.Time, instrumentUID string, delta int) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.counts[freeOrderKey(tradeDate, instrumentUID)] += delta
|
||||
return nil
|
||||
}
|
||||
|
||||
func freeOrderKey(tradeDate time.Time, instrumentUID string) string {
|
||||
return tradeDate.Format("2006-01-02") + "|" + instrumentUID
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package risk
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
func rd(raw string) decimal.Decimal {
|
||||
v, err := decimal.NewFromString(raw)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func TestSizerTakesMinimumOfLimits(t *testing.T) {
|
||||
sizer := NewSizer(SizingConfig{
|
||||
MaxPositionPct: rd("0.10"),
|
||||
MaxTotalExposurePct: rd("0.50"),
|
||||
MaxParticipationRate: rd("0.01"),
|
||||
CashUsageBuffer: rd("0.95"),
|
||||
RiskBudgetPerInstrumentPct: rd("0.005"),
|
||||
MinOrderNotionalRUB: rd("1000"),
|
||||
})
|
||||
got := sizer.Size(SizingInput{
|
||||
Portfolio: domain.Portfolio{Equity: rd("100000"), Cash: rd("90000")},
|
||||
SelectedInstruments: 5,
|
||||
LimitPrice: rd("100"),
|
||||
Lot: 1,
|
||||
EntryIntervalVolume: rd("1000000"),
|
||||
ExitIntervalVolume: rd("1000000"),
|
||||
Q05OvernightAbs: rd("0.05"),
|
||||
})
|
||||
if got.Lots != 100 || !got.TargetNotional.Equal(rd("10000")) {
|
||||
t.Fatalf("unexpected sizing: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSizerMinOrderGate(t *testing.T) {
|
||||
sizer := NewSizer(SizingConfig{
|
||||
MaxPositionPct: rd("0.10"),
|
||||
MaxTotalExposurePct: rd("0.50"),
|
||||
MaxParticipationRate: rd("0.01"),
|
||||
CashUsageBuffer: rd("0.95"),
|
||||
RiskBudgetPerInstrumentPct: rd("0.005"),
|
||||
MinOrderNotionalRUB: rd("1000"),
|
||||
})
|
||||
got := sizer.Size(SizingInput{
|
||||
Portfolio: domain.Portfolio{Equity: rd("10000"), Cash: rd("10000")},
|
||||
SelectedInstruments: 1,
|
||||
LimitPrice: rd("999"),
|
||||
Lot: 1,
|
||||
EntryIntervalVolume: rd("1000000"),
|
||||
ExitIntervalVolume: rd("1000000"),
|
||||
Q05OvernightAbs: rd("0.05"),
|
||||
})
|
||||
if got.Lots != 0 || got.Reason != "min_order_notional" {
|
||||
t.Fatalf("unexpected min order gate: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSizerBindingLimits(t *testing.T) {
|
||||
sizer := NewSizer(SizingConfig{
|
||||
MaxPositionPct: rd("0.10"),
|
||||
MaxTotalExposurePct: rd("0.50"),
|
||||
MaxParticipationRate: rd("0.01"),
|
||||
CashUsageBuffer: rd("0.95"),
|
||||
RiskBudgetPerInstrumentPct: rd("0.005"),
|
||||
MinOrderNotionalRUB: rd("1"),
|
||||
})
|
||||
tests := []struct {
|
||||
name string
|
||||
input SizingInput
|
||||
want decimal.Decimal
|
||||
}{
|
||||
{
|
||||
name: "cap",
|
||||
input: SizingInput{
|
||||
Portfolio: domain.Portfolio{Equity: rd("100000"), Cash: rd("100000")},
|
||||
SelectedInstruments: 1,
|
||||
LimitPrice: rd("100"),
|
||||
Lot: 1,
|
||||
EntryIntervalVolume: rd("5000000"),
|
||||
ExitIntervalVolume: rd("5000000"),
|
||||
},
|
||||
want: rd("10000"),
|
||||
},
|
||||
{
|
||||
name: "exposure",
|
||||
input: SizingInput{
|
||||
Portfolio: domain.Portfolio{Equity: rd("100000"), Cash: rd("100000")},
|
||||
SelectedInstruments: 10,
|
||||
LimitPrice: rd("100"),
|
||||
Lot: 1,
|
||||
EntryIntervalVolume: rd("5000000"),
|
||||
ExitIntervalVolume: rd("5000000"),
|
||||
},
|
||||
want: rd("5000"),
|
||||
},
|
||||
{
|
||||
name: "liquidity",
|
||||
input: SizingInput{
|
||||
Portfolio: domain.Portfolio{Equity: rd("100000"), Cash: rd("100000")},
|
||||
SelectedInstruments: 1,
|
||||
LimitPrice: rd("100"),
|
||||
Lot: 1,
|
||||
EntryIntervalVolume: rd("300000"),
|
||||
ExitIntervalVolume: rd("500000"),
|
||||
},
|
||||
want: rd("3000"),
|
||||
},
|
||||
{
|
||||
name: "risk",
|
||||
input: SizingInput{
|
||||
Portfolio: domain.Portfolio{Equity: rd("100000"), Cash: rd("100000")},
|
||||
SelectedInstruments: 1,
|
||||
LimitPrice: rd("100"),
|
||||
Lot: 1,
|
||||
EntryIntervalVolume: rd("5000000"),
|
||||
ExitIntervalVolume: rd("5000000"),
|
||||
Q05OvernightAbs: rd("0.10"),
|
||||
},
|
||||
want: rd("5000"),
|
||||
},
|
||||
{
|
||||
name: "cash",
|
||||
input: SizingInput{
|
||||
Portfolio: domain.Portfolio{Equity: rd("100000"), Cash: rd("2000")},
|
||||
SelectedInstruments: 1,
|
||||
LimitPrice: rd("100"),
|
||||
Lot: 1,
|
||||
EntryIntervalVolume: rd("5000000"),
|
||||
ExitIntervalVolume: rd("5000000"),
|
||||
},
|
||||
want: rd("1900"),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := sizer.Size(tt.input)
|
||||
if !got.TargetNotional.Equal(tt.want) {
|
||||
t.Fatalf("target=%s, want %s limits=%v", got.TargetNotional, tt.want, got.Limits)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSizerAppliesSizeReductionFactor(t *testing.T) {
|
||||
sizer := NewSizer(SizingConfig{
|
||||
MaxPositionPct: rd("1"),
|
||||
MaxTotalExposurePct: rd("1"),
|
||||
MaxParticipationRate: rd("1"),
|
||||
CashUsageBuffer: rd("1"),
|
||||
RiskBudgetPerInstrumentPct: rd("1"),
|
||||
MinOrderNotionalRUB: rd("1"),
|
||||
}).WithSizeFactor(rd("0.5"))
|
||||
got := sizer.Size(SizingInput{
|
||||
Portfolio: domain.Portfolio{Equity: rd("10000"), Cash: rd("10000")},
|
||||
SelectedInstruments: 1,
|
||||
LimitPrice: rd("100"),
|
||||
Lot: 1,
|
||||
EntryIntervalVolume: rd("10000"),
|
||||
ExitIntervalVolume: rd("10000"),
|
||||
Q05OvernightAbs: rd("1"),
|
||||
})
|
||||
if got.Lots != 50 || !got.TargetNotional.Equal(rd("5000")) {
|
||||
t.Fatalf("unexpected reduced sizing: %+v", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,905 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/execution"
|
||||
"overnight-trading-bot/internal/features"
|
||||
"overnight-trading-bot/internal/instruments"
|
||||
"overnight-trading-bot/internal/marketdata"
|
||||
"overnight-trading-bot/internal/money"
|
||||
"overnight-trading-bot/internal/notify"
|
||||
"overnight-trading-bot/internal/position"
|
||||
"overnight-trading-bot/internal/reconciliation"
|
||||
"overnight-trading-bot/internal/report"
|
||||
"overnight-trading-bot/internal/repository"
|
||||
"overnight-trading-bot/internal/risk"
|
||||
"overnight-trading-bot/internal/signal"
|
||||
"overnight-trading-bot/internal/statemachine"
|
||||
"overnight-trading-bot/internal/timeutil"
|
||||
"overnight-trading-bot/internal/tinvest"
|
||||
)
|
||||
|
||||
const (
|
||||
sizeReductionWindowTrades = 20
|
||||
sizeReductionFactor = 0.5
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Mode domain.Mode
|
||||
Location *time.Location
|
||||
RollingLong int
|
||||
TickInterval time.Duration
|
||||
EntrySignalTime timeutil.TimeOfDay
|
||||
EntryWindowStart timeutil.TimeOfDay
|
||||
EntryWindowEnd timeutil.TimeOfDay
|
||||
NoNewEntryAfter timeutil.TimeOfDay
|
||||
ExitWatchStart timeutil.TimeOfDay
|
||||
ExitWindowStart timeutil.TimeOfDay
|
||||
ExitWindowEnd timeutil.TimeOfDay
|
||||
HardExitDeadline timeutil.TimeOfDay
|
||||
QuoteDepth int32
|
||||
MaxQuoteAge time.Duration
|
||||
OrderPollInterval time.Duration
|
||||
PassiveImproveTicks int
|
||||
MaxEntryOrderAttempts int
|
||||
MaxExitOrderAttempts int
|
||||
MinTimeToClose time.Duration
|
||||
MaxClockDrift time.Duration
|
||||
APIOutageHalt time.Duration
|
||||
}
|
||||
|
||||
type Services struct {
|
||||
Repo repository.Repository
|
||||
Gateway tinvest.Gateway
|
||||
Registry instruments.Registry
|
||||
MarketData marketdata.Loader
|
||||
Features features.Pipeline
|
||||
Signals signal.Engine
|
||||
Sizer risk.Sizer
|
||||
FreeOrders risk.FreeOrderBudget
|
||||
Risk risk.Manager
|
||||
Execution *execution.Engine
|
||||
Positions position.Manager
|
||||
Reconcile reconciliation.Engine
|
||||
Notifier notify.Notifier
|
||||
AccountID string
|
||||
AccountIDHash string
|
||||
Log *slog.Logger
|
||||
}
|
||||
|
||||
type Scheduler struct {
|
||||
clock timeutil.Clock
|
||||
sm statemachine.System
|
||||
cfg Config
|
||||
svc Services
|
||||
|
||||
infraFailedSince time.Time
|
||||
}
|
||||
|
||||
func New(clock timeutil.Clock, sm statemachine.System, cfg Config, svc Services) Scheduler {
|
||||
if cfg.TickInterval <= 0 {
|
||||
cfg.TickInterval = 30 * time.Second
|
||||
}
|
||||
if cfg.Location == nil {
|
||||
cfg.Location = time.UTC
|
||||
}
|
||||
return Scheduler{clock: clock, sm: sm, cfg: cfg, svc: svc}
|
||||
}
|
||||
|
||||
func (s *Scheduler) Run(ctx context.Context) error {
|
||||
for {
|
||||
if err := s.Step(ctx); err != nil {
|
||||
if errors.Is(err, statemachine.ErrSystemHalted) {
|
||||
s.logWarn("scheduler paused in HALT", "err", err)
|
||||
} else if err := s.halt(ctx, "scheduler_error", err.Error(), ""); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if !s.clock.Sleep(ctx.Done(), s.cfg.TickInterval) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scheduler) Step(ctx context.Context) error {
|
||||
if err := s.checkInfrastructure(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
now := s.clock.Now().In(s.cfg.Location)
|
||||
phase := s.phase(now)
|
||||
switch phase {
|
||||
case domain.StateWaitExitWindow:
|
||||
return s.waitExit(ctx, now)
|
||||
case domain.StatePlaceExitOrders:
|
||||
return s.placeExitOrders(ctx, now)
|
||||
case domain.StateMonitorExitOrders:
|
||||
return s.monitorExitOrders(ctx, now)
|
||||
case domain.StateReconcile:
|
||||
return s.failOpenPositionsAtHardDeadline(ctx)
|
||||
case domain.StateGenerateSignals:
|
||||
return s.prepareSignals(ctx, now)
|
||||
case domain.StatePlaceEntryOrders:
|
||||
return s.placeEntryOrders(ctx, now)
|
||||
case domain.StateMonitorEntryOrders:
|
||||
return s.monitorEntryOrders(ctx, now)
|
||||
case domain.StateHoldOvernight:
|
||||
return s.holdOvernight(ctx)
|
||||
default:
|
||||
return s.sm.Heartbeat(ctx, domain.StateSleep)
|
||||
}
|
||||
}
|
||||
|
||||
func (s Scheduler) phase(now time.Time) domain.SystemState {
|
||||
tod := sinceMidnight(now)
|
||||
switch {
|
||||
case tod >= s.cfg.ExitWatchStart.Duration && tod < s.cfg.ExitWindowStart.Duration:
|
||||
return domain.StateWaitExitWindow
|
||||
case tod >= s.cfg.ExitWindowStart.Duration && tod < s.cfg.ExitWindowEnd.Duration:
|
||||
return domain.StatePlaceExitOrders
|
||||
case tod >= s.cfg.ExitWindowEnd.Duration && tod < s.cfg.HardExitDeadline.Duration:
|
||||
return domain.StateMonitorExitOrders
|
||||
case tod >= s.cfg.HardExitDeadline.Duration && tod < s.cfg.EntrySignalTime.Duration:
|
||||
return domain.StateReconcile
|
||||
case tod >= s.cfg.EntrySignalTime.Duration && tod < s.cfg.EntryWindowStart.Duration:
|
||||
return domain.StateGenerateSignals
|
||||
case tod >= s.cfg.EntryWindowStart.Duration && tod < s.cfg.NoNewEntryAfter.Duration:
|
||||
return domain.StatePlaceEntryOrders
|
||||
case tod >= s.cfg.NoNewEntryAfter.Duration:
|
||||
return domain.StateHoldOvernight
|
||||
default:
|
||||
return domain.StateSleep
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scheduler) prepareSignals(ctx context.Context, now time.Time) error {
|
||||
if err := s.transitionSequence(ctx,
|
||||
domain.StateInit,
|
||||
domain.StateSyncInstruments,
|
||||
domain.StateSyncMarketData,
|
||||
domain.StateGenerateSignals,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.svc.Registry.SyncMetadata(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
tradeDate := tradingDate(now)
|
||||
instrumentsList, err := s.svc.Repo.ListInstruments(ctx, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.svc.MarketData.BackfillDaily(ctx, instrumentsList, tradeDate.AddDate(0, 0, -s.cfg.RollingLong-10), tradeDate); err != nil {
|
||||
return err
|
||||
}
|
||||
minuteFrom := s.cfg.EntryWindowStart.On(tradeDate, s.cfg.Location)
|
||||
minuteTo := s.cfg.ExitWindowEnd.On(tradeDate.AddDate(0, 0, 1), s.cfg.Location)
|
||||
if err := s.svc.MarketData.BackfillMinute(ctx, instrumentsList, minuteFrom, minuteTo); err != nil {
|
||||
s.logWarn("minute backfill failed; liquidity will fall back to ADV", "err", err)
|
||||
}
|
||||
if err := s.applySizeReductionRule(ctx, tradeDate, false); err != nil {
|
||||
return err
|
||||
}
|
||||
portfolio, err := s.svc.Gateway.GetPortfolio(ctx, s.svc.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
openPositions, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, instrument := range instrumentsList {
|
||||
if err := s.generateInstrumentSignal(ctx, now, tradeDate, portfolio, len(openPositions), instrument); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return s.transitionTo(ctx, domain.StateWaitEntryWindow)
|
||||
}
|
||||
|
||||
func (s Scheduler) generateInstrumentSignal(ctx context.Context, now, tradeDate time.Time, portfolio domain.Portfolio, openPositionCount int, instrument domain.Instrument) error {
|
||||
book, err := s.svc.MarketData.LatestQuote(ctx, instrument.InstrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge)
|
||||
if err != nil {
|
||||
return s.saveRejectedSignal(ctx, tradeDate, instrument, "quote_unavailable", err)
|
||||
}
|
||||
spread, err := spreadFromBook(book, instrument.MinPriceIncrement)
|
||||
if err != nil {
|
||||
return s.saveRejectedSignal(ctx, tradeDate, instrument, "spread_unavailable", err)
|
||||
}
|
||||
tradingStatus, err := s.svc.Gateway.GetTradingStatus(ctx, instrument.InstrumentUID)
|
||||
if err != nil {
|
||||
tradingStatus = domain.TradingStatusUnknown
|
||||
}
|
||||
feature, err := s.svc.Features.Recompute(ctx, instrument, tradeDate, spread)
|
||||
if err != nil {
|
||||
return s.saveRejectedSignal(ctx, tradeDate, instrument, "features_unavailable", err)
|
||||
}
|
||||
remaining, err := s.svc.FreeOrders.Check(ctx, tradeDate, instrument, 1)
|
||||
freeOrderOK := err == nil
|
||||
sig := s.svc.Signals.Evaluate(signal.Candidate{
|
||||
Instrument: instrument,
|
||||
Features: feature,
|
||||
TradingStatus: tradingStatus,
|
||||
FreeOrderOK: freeOrderOK,
|
||||
OpenPositions: openPositionCount,
|
||||
TradeDate: tradeDate,
|
||||
ExtraContext: map[string]any{
|
||||
"free_orders_remaining": remaining,
|
||||
"quote_time": book.Time.Format(time.RFC3339),
|
||||
},
|
||||
})
|
||||
if sig.Decision == domain.DecisionEnter {
|
||||
sized, sizingErr := s.sizeSignal(ctx, portfolio, instrument, feature, book, 1)
|
||||
switch {
|
||||
case sizingErr != nil:
|
||||
sig.Decision = domain.DecisionReject
|
||||
sig.RejectReason = sizingErr.Error()
|
||||
case sized.Lots <= 0:
|
||||
sig.Decision = domain.DecisionReject
|
||||
sig.RejectReason = sized.Reason
|
||||
default:
|
||||
sig.TargetLots = sized.Lots
|
||||
sig.TargetNotional = sized.TargetNotional
|
||||
}
|
||||
}
|
||||
if err := s.svc.Repo.UpsertSignal(ctx, sig); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.notifySignal(ctx, now, sig)
|
||||
}
|
||||
|
||||
func (s Scheduler) saveRejectedSignal(ctx context.Context, tradeDate time.Time, instrument domain.Instrument, reason string, cause error) error {
|
||||
sig := domain.Signal{
|
||||
TradeDate: tradeDate,
|
||||
InstrumentUID: instrument.InstrumentUID,
|
||||
Decision: domain.DecisionReject,
|
||||
RejectReason: reason,
|
||||
ContextJSON: fmt.Sprintf(`{"error":%q}`, cause.Error()),
|
||||
CreatedAt: s.nowUTC(),
|
||||
}
|
||||
return s.svc.Repo.UpsertSignal(ctx, sig)
|
||||
}
|
||||
|
||||
func (s Scheduler) sizeSignal(_ context.Context, portfolio domain.Portfolio, instrument domain.Instrument, feature domain.FeatureSet, book domain.OrderBook, selected int) (risk.SizingResult, error) {
|
||||
bid, ask, err := bestBidAsk(book)
|
||||
if err != nil {
|
||||
return risk.SizingResult{}, err
|
||||
}
|
||||
price, err := execution.LimitBuyPrice(bid, ask, instrument.MinPriceIncrement, s.cfg.PassiveImproveTicks)
|
||||
if err != nil {
|
||||
return risk.SizingResult{}, err
|
||||
}
|
||||
return s.svc.Sizer.Size(risk.SizingInput{
|
||||
Portfolio: portfolio,
|
||||
SelectedInstruments: selected,
|
||||
LimitPrice: price,
|
||||
Lot: instrument.Lot,
|
||||
EntryIntervalVolume: feature.EntryIntervalVolume,
|
||||
ExitIntervalVolume: feature.ExitIntervalVolume,
|
||||
Q05OvernightAbs: money.Abs(feature.SigmaOn60).Mul(decimal.NewFromFloat(1.65)),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s Scheduler) placeEntryOrders(ctx context.Context, now time.Time) error {
|
||||
if err := s.transitionTo(ctx, domain.StatePlaceEntryOrders); err != nil {
|
||||
return err
|
||||
}
|
||||
tradeDate := tradingDate(now)
|
||||
signals, err := s.svc.Repo.ListSignals(ctx, tradeDate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
existing, err := s.svc.Repo.ListOrders(ctx, s.svc.AccountIDHash, tradeDate, tradeDate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
openPositions, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
instrumentByUID, err := s.instrumentMap(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, sig := range signals {
|
||||
if sig.Decision != domain.DecisionEnter || sig.TargetLots <= 0 || hasOrder(existing, sig.InstrumentUID, domain.SideBuy) {
|
||||
continue
|
||||
}
|
||||
instrument, ok := instrumentByUID[sig.InstrumentUID]
|
||||
if !ok {
|
||||
return fmt.Errorf("instrument %s is not in registry", sig.InstrumentUID)
|
||||
}
|
||||
book, err := s.svc.MarketData.LatestQuote(ctx, sig.InstrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tradingStatus, err := s.svc.Gateway.GetTradingStatus(ctx, sig.InstrumentUID)
|
||||
if err != nil {
|
||||
tradingStatus = domain.TradingStatusUnknown
|
||||
}
|
||||
portfolio, err := s.svc.Gateway.GetPortfolio(ctx, s.svc.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pre := s.svc.Risk.PreTradeCheck(risk.PreTradeInput{
|
||||
Portfolio: portfolio,
|
||||
OpenPositions: len(openPositions),
|
||||
TradingStatus: tradingStatus,
|
||||
QuoteReceivedAt: book.ReceivedAt,
|
||||
Now: now.UTC(),
|
||||
MarketClose: s.cfg.EntryWindowEnd.On(now, s.cfg.Location).UTC(),
|
||||
})
|
||||
if !pre.Allowed {
|
||||
if err := s.svc.Repo.InsertRiskEvent(ctx, domain.RiskEvent{
|
||||
Severity: domain.SeverityWarn,
|
||||
EventType: "pre_trade_reject",
|
||||
InstrumentUID: sig.InstrumentUID,
|
||||
Message: pre.Reason,
|
||||
ContextJSON: "{}",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
placed, err := s.svc.Execution.PlaceEntry(ctx, s.svc.AccountIDHash, instrument, tradeDate, sig.TargetLots, book, s.cfg.PassiveImproveTicks, 1)
|
||||
if err != nil && !errors.Is(err, execution.ErrBrokerOrdersDisabled) {
|
||||
return err
|
||||
}
|
||||
_ = s.svc.Notifier.Info(ctx, fmt.Sprintf("entry order %s %s lots=%d status=%s", instrument.Ticker, placed.Side, placed.QuantityLots, placed.Status))
|
||||
existing = append(existing, placed)
|
||||
}
|
||||
return s.transitionTo(ctx, domain.StateMonitorEntryOrders)
|
||||
}
|
||||
|
||||
func (s Scheduler) monitorEntryOrders(ctx context.Context, now time.Time) error {
|
||||
if err := s.transitionTo(ctx, domain.StateMonitorEntryOrders); err != nil {
|
||||
return err
|
||||
}
|
||||
orders, err := s.svc.Repo.ListActiveOrders(ctx, s.svc.AccountIDHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
instrumentByUID, err := s.instrumentMap(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
deadline := s.cfg.NoNewEntryAfter.On(now, s.cfg.Location).UTC()
|
||||
for _, order := range orders {
|
||||
if order.Side != domain.SideBuy || order.BrokerOrderID == "" {
|
||||
continue
|
||||
}
|
||||
instrument, ok := instrumentByUID[order.InstrumentUID]
|
||||
if !ok {
|
||||
return fmt.Errorf("instrument %s is not in registry", order.InstrumentUID)
|
||||
}
|
||||
monitored, err := s.svc.Execution.MonitorUntil(ctx, order, execution.MonitorConfig{
|
||||
Deadline: deadline,
|
||||
PollInterval: s.cfg.OrderPollInterval,
|
||||
MaxAttempts: s.cfg.MaxEntryOrderAttempts,
|
||||
RepostAfter: repostAfter(now, deadline, s.cfg.MaxEntryOrderAttempts, s.cfg.OrderPollInterval),
|
||||
Instrument: instrument,
|
||||
ImproveTicks: s.cfg.PassiveImproveTicks,
|
||||
Quote: func(ctx context.Context, instrumentUID string) (domain.OrderBook, error) {
|
||||
return s.svc.MarketData.LatestQuote(ctx, instrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge)
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if monitored.FilledLots > order.FilledLots || monitored.Commission.GreaterThan(order.Commission) {
|
||||
pos, err := s.svc.Positions.OnEntryFill(ctx, s.svc.AccountIDHash, instrument, monitored)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = s.svc.Notifier.Info(ctx, fmt.Sprintf("entry fill %s lots=%d status=%s", monitored.InstrumentUID, monitored.FilledLots, pos.Status))
|
||||
}
|
||||
}
|
||||
if sinceMidnight(s.nowUTC().In(s.cfg.Location)) >= s.cfg.NoNewEntryAfter.Duration {
|
||||
if err := s.cancelActiveOrders(ctx, domain.SideBuy, domain.OrderStatusCancelled, "entry_window_closed"); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.transitionTo(ctx, domain.StateHoldOvernight)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s Scheduler) waitExit(ctx context.Context, _ time.Time) error {
|
||||
return s.transitionTo(ctx, domain.StateWaitExitWindow)
|
||||
}
|
||||
|
||||
func (s Scheduler) holdOvernight(ctx context.Context) error {
|
||||
if err := s.cancelActiveOrders(ctx, domain.SideBuy, domain.OrderStatusCancelled, "entry_window_closed"); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.transitionTo(ctx, domain.StateHoldOvernight)
|
||||
}
|
||||
|
||||
func (s Scheduler) placeExitOrders(ctx context.Context, now time.Time) error {
|
||||
if err := s.transitionTo(ctx, domain.StatePlaceExitOrders); err != nil {
|
||||
return err
|
||||
}
|
||||
positionsList, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
existing, err := s.svc.Repo.ListOrders(ctx, s.svc.AccountIDHash, tradingDate(now).AddDate(0, 0, -1), tradingDate(now))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
instrumentByUID, err := s.instrumentMap(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, pos := range positionsList {
|
||||
if pos.Lots <= 0 || hasOrder(existing, pos.InstrumentUID, domain.SideSell) {
|
||||
continue
|
||||
}
|
||||
instrument, ok := instrumentByUID[pos.InstrumentUID]
|
||||
if !ok {
|
||||
return fmt.Errorf("instrument %s is not in registry", pos.InstrumentUID)
|
||||
}
|
||||
book, err := s.svc.MarketData.LatestQuote(ctx, pos.InstrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tradingStatus, err := s.svc.Gateway.GetTradingStatus(ctx, pos.InstrumentUID)
|
||||
if err != nil {
|
||||
tradingStatus = domain.TradingStatusUnknown
|
||||
}
|
||||
portfolio, err := s.svc.Gateway.GetPortfolio(ctx, s.svc.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pre := s.svc.Risk.PreTradeCheck(risk.PreTradeInput{
|
||||
Portfolio: portfolio,
|
||||
OpenPositions: len(positionsList),
|
||||
TradingStatus: tradingStatus,
|
||||
QuoteReceivedAt: book.ReceivedAt,
|
||||
Now: now.UTC(),
|
||||
MarketClose: s.cfg.HardExitDeadline.On(now, s.cfg.Location).UTC(),
|
||||
})
|
||||
if !pre.Allowed {
|
||||
return fmt.Errorf("exit pre-trade rejected: %s", pre.Reason)
|
||||
}
|
||||
placed, err := s.svc.Execution.PlaceExit(ctx, s.svc.AccountIDHash, instrument, pos.OpenTradeDate, pos.Lots, book, s.cfg.PassiveImproveTicks, 1)
|
||||
if err != nil && !errors.Is(err, execution.ErrBrokerOrdersDisabled) {
|
||||
return err
|
||||
}
|
||||
pos.Status = domain.PositionExitOrderSent
|
||||
if err := s.svc.Repo.UpsertPosition(ctx, pos); err != nil {
|
||||
return err
|
||||
}
|
||||
_ = s.svc.Notifier.Info(ctx, fmt.Sprintf("exit order %s lots=%d status=%s", instrument.Ticker, placed.QuantityLots, placed.Status))
|
||||
existing = append(existing, placed)
|
||||
}
|
||||
return s.transitionTo(ctx, domain.StateMonitorExitOrders)
|
||||
}
|
||||
|
||||
func (s Scheduler) monitorExitOrders(ctx context.Context, now time.Time) error {
|
||||
if err := s.transitionTo(ctx, domain.StateMonitorExitOrders); err != nil {
|
||||
return err
|
||||
}
|
||||
orders, err := s.svc.Repo.ListActiveOrders(ctx, s.svc.AccountIDHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
openPositions, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
positionByInstrument := make(map[string]domain.Position, len(openPositions))
|
||||
for _, pos := range openPositions {
|
||||
positionByInstrument[pos.InstrumentUID] = pos
|
||||
}
|
||||
instrumentByUID, err := s.instrumentMap(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
deadline := s.cfg.HardExitDeadline.On(now, s.cfg.Location).UTC()
|
||||
for _, order := range orders {
|
||||
if order.Side != domain.SideSell || order.BrokerOrderID == "" {
|
||||
continue
|
||||
}
|
||||
instrument, ok := instrumentByUID[order.InstrumentUID]
|
||||
if !ok {
|
||||
return fmt.Errorf("instrument %s is not in registry", order.InstrumentUID)
|
||||
}
|
||||
monitored, err := s.svc.Execution.MonitorUntil(ctx, order, execution.MonitorConfig{
|
||||
Deadline: deadline,
|
||||
PollInterval: s.cfg.OrderPollInterval,
|
||||
MaxAttempts: s.cfg.MaxExitOrderAttempts,
|
||||
RepostAfter: repostAfter(now, deadline, s.cfg.MaxExitOrderAttempts, s.cfg.OrderPollInterval),
|
||||
Instrument: instrument,
|
||||
ImproveTicks: s.cfg.PassiveImproveTicks,
|
||||
Quote: func(ctx context.Context, instrumentUID string) (domain.OrderBook, error) {
|
||||
return s.svc.MarketData.LatestQuote(ctx, instrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge)
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if monitored.FilledLots > order.FilledLots || monitored.Commission.GreaterThan(order.Commission) {
|
||||
fill := exitFillDelta(order, monitored)
|
||||
if fill.FilledLots <= 0 && fill.Commission.IsZero() {
|
||||
continue
|
||||
}
|
||||
pos, ok := positionByInstrument[monitored.InstrumentUID]
|
||||
if !ok {
|
||||
return fmt.Errorf("exit fill for unknown local position %s", monitored.InstrumentUID)
|
||||
}
|
||||
updated, err := s.svc.Positions.OnExitFill(ctx, pos, fill)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
positionByInstrument[monitored.InstrumentUID] = updated
|
||||
_ = s.svc.Notifier.Info(ctx, fmt.Sprintf("exit fill %s lots=%d status=%s pnl=%s", monitored.InstrumentUID, monitored.FilledLots, updated.Status, updated.NetPnL.StringFixed(2)))
|
||||
}
|
||||
}
|
||||
if sinceMidnight(s.nowUTC().In(s.cfg.Location)) >= s.cfg.HardExitDeadline.Duration {
|
||||
return s.failOpenPositionsAtHardDeadline(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Scheduler) reconcileAndReport(ctx context.Context, now time.Time) error {
|
||||
if err := s.transitionTo(ctx, domain.StateReconcile); err != nil {
|
||||
return err
|
||||
}
|
||||
diffs, err := s.svc.Reconcile.Run(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if reconciliation.HasCritical(diffs) {
|
||||
return s.halt(ctx, "reconciliation_critical", "critical reconciliation diff", "")
|
||||
}
|
||||
tradeDate := tradingDate(now)
|
||||
sent, err := s.svc.Repo.WasDailyReportSent(ctx, tradeDate, s.svc.AccountIDHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sent {
|
||||
s.logWarn("daily report already sent; skipping duplicate", "date", tradeDate.Format("2006-01-02"))
|
||||
return s.transitionTo(ctx, domain.StateSleep)
|
||||
}
|
||||
signals, err := s.svc.Repo.ListSignals(ctx, tradeDate)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
positionsList, err := s.svc.Repo.ListPositions(ctx, s.svc.AccountIDHash, tradeDate.AddDate(0, 0, -1), tradeDate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.applySizeReductionRule(ctx, tradeDate, true); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.transitionTo(ctx, domain.StateReport); err != nil {
|
||||
return err
|
||||
}
|
||||
msg := report.ComposeDaily(report.DailyInput{
|
||||
Date: tradeDate,
|
||||
Mode: s.cfg.Mode,
|
||||
Signals: signals,
|
||||
Positions: positionsList,
|
||||
RiskStatus: "ok",
|
||||
})
|
||||
if err := s.svc.Notifier.Report(ctx, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.svc.Repo.MarkDailyReportSent(ctx, tradeDate, s.svc.AccountIDHash); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.transitionTo(ctx, domain.StateSleep)
|
||||
}
|
||||
|
||||
func (s *Scheduler) applySizeReductionRule(ctx context.Context, tradeDate time.Time, emitEvent bool) error {
|
||||
averageError, count, ok, err := s.averageExpectedErrorBps(ctx, tradeDate, sizeReductionWindowTrades)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok || count < sizeReductionWindowTrades || averageError.GreaterThanOrEqual(decimal.NewFromInt(-10)) {
|
||||
s.svc.Sizer = s.svc.Sizer.WithSizeFactor(decimal.NewFromInt(1))
|
||||
return nil
|
||||
}
|
||||
factor := decimal.NewFromFloat(sizeReductionFactor)
|
||||
s.svc.Sizer = s.svc.Sizer.WithSizeFactor(factor)
|
||||
if !emitEvent {
|
||||
return nil
|
||||
}
|
||||
return s.svc.Repo.InsertRiskEvent(ctx, domain.RiskEvent{
|
||||
Severity: domain.SeverityWarn,
|
||||
EventType: "size_reduction_rule_triggered",
|
||||
Message: fmt.Sprintf("average expected_error_bps over %d trades is %s; sizing factor set to %s", count, averageError.StringFixed(2), factor.String()),
|
||||
ContextJSON: fmt.Sprintf(`{"average_expected_error_bps":%q,"trades":%d,"size_factor":%q}`, averageError.String(), count, factor.String()),
|
||||
})
|
||||
}
|
||||
|
||||
func (s Scheduler) averageExpectedErrorBps(ctx context.Context, tradeDate time.Time, limit int) (decimal.Decimal, int, bool, error) {
|
||||
if limit <= 0 {
|
||||
return decimal.Zero, 0, false, nil
|
||||
}
|
||||
positionsList, err := s.svc.Repo.ListPositions(ctx, s.svc.AccountIDHash, tradeDate.AddDate(0, 0, -120), tradeDate)
|
||||
if err != nil {
|
||||
return decimal.Zero, 0, false, err
|
||||
}
|
||||
sort.Slice(positionsList, func(i, j int) bool {
|
||||
return positionsList[i].UpdatedAt.After(positionsList[j].UpdatedAt)
|
||||
})
|
||||
signalsByDate := make(map[string][]domain.Signal)
|
||||
var errorsBps []decimal.Decimal
|
||||
for _, pos := range positionsList {
|
||||
if pos.Status != domain.PositionExitFilled {
|
||||
continue
|
||||
}
|
||||
key := tradingDate(pos.OpenTradeDate).Format("2006-01-02")
|
||||
signals, ok := signalsByDate[key]
|
||||
if !ok {
|
||||
signals, err = s.svc.Repo.ListSignals(ctx, tradingDate(pos.OpenTradeDate))
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return decimal.Zero, 0, false, err
|
||||
}
|
||||
signalsByDate[key] = signals
|
||||
}
|
||||
for _, sig := range signals {
|
||||
if sig.InstrumentUID != pos.InstrumentUID || sig.Decision != domain.DecisionEnter {
|
||||
continue
|
||||
}
|
||||
errorsBps = append(errorsBps, pos.RealizedEdgeBps.Sub(sig.NetEdgeBps))
|
||||
break
|
||||
}
|
||||
if len(errorsBps) == limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(errorsBps) == 0 {
|
||||
return decimal.Zero, 0, false, nil
|
||||
}
|
||||
sum := decimal.Zero
|
||||
for _, value := range errorsBps {
|
||||
sum = sum.Add(value)
|
||||
}
|
||||
return sum.Div(decimal.NewFromInt(int64(len(errorsBps)))), len(errorsBps), true, nil
|
||||
}
|
||||
|
||||
func (s *Scheduler) checkInfrastructure(ctx context.Context) error {
|
||||
if s.cfg.MaxClockDrift <= 0 || s.svc.Gateway == nil {
|
||||
return nil
|
||||
}
|
||||
serverTime, err := s.svc.Gateway.GetServerTime(ctx)
|
||||
if err != nil {
|
||||
if s.cfg.Mode == domain.ModePaper {
|
||||
return nil
|
||||
}
|
||||
return s.recordInfrastructureFailure(fmt.Errorf("server_time_unavailable: %w", err))
|
||||
}
|
||||
drift := timeutil.Drift(s.nowUTC(), serverTime)
|
||||
if drift > s.cfg.MaxClockDrift {
|
||||
return s.recordInfrastructureFailure(fmt.Errorf("server_clock_drift_too_high: %s > %s", drift, s.cfg.MaxClockDrift))
|
||||
}
|
||||
s.infraFailedSince = time.Time{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Scheduler) recordInfrastructureFailure(err error) error {
|
||||
now := s.nowUTC()
|
||||
if s.infraFailedSince.IsZero() {
|
||||
s.infraFailedSince = now
|
||||
s.logWarn("infrastructure check failed; waiting for outage threshold", "err", err, "threshold", s.cfg.APIOutageHalt)
|
||||
return nil
|
||||
}
|
||||
if s.cfg.APIOutageHalt <= 0 || now.Sub(s.infraFailedSince) >= s.cfg.APIOutageHalt {
|
||||
return err
|
||||
}
|
||||
s.logWarn("infrastructure check still failing", "err", err, "elapsed", now.Sub(s.infraFailedSince), "threshold", s.cfg.APIOutageHalt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s Scheduler) cancelActiveOrders(ctx context.Context, side domain.Side, fallbackStatus domain.OrderStatus, reason string) error {
|
||||
orders, err := s.svc.Repo.ListActiveOrders(ctx, s.svc.AccountIDHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cancelled := 0
|
||||
for _, order := range orders {
|
||||
if order.Side != side {
|
||||
continue
|
||||
}
|
||||
if order.BrokerOrderID != "" && s.cfg.Mode.AllowsBrokerOrders() {
|
||||
if err := s.svc.Execution.Cancel(ctx, order); err != nil {
|
||||
return fmt.Errorf("cancel %s order %s: %w", side, order.ClientOrderID, err)
|
||||
}
|
||||
cancelled++
|
||||
continue
|
||||
}
|
||||
if err := s.svc.Repo.UpdateOrderStatus(ctx, order.ClientOrderID, fallbackStatus, order.FilledLots, order.RawStateJSON); err != nil {
|
||||
return fmt.Errorf("mark %s order %s %s: %w", side, order.ClientOrderID, fallbackStatus, err)
|
||||
}
|
||||
cancelled++
|
||||
}
|
||||
if cancelled == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := s.svc.Repo.InsertRiskEvent(ctx, domain.RiskEvent{
|
||||
Severity: domain.SeverityWarn,
|
||||
EventType: reason,
|
||||
Message: fmt.Sprintf("cancelled %d active %s orders at window boundary", cancelled, side),
|
||||
ContextJSON: "{}",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s Scheduler) failOpenPositionsAtHardDeadline(ctx context.Context) error {
|
||||
if err := s.cancelActiveOrders(ctx, domain.SideSell, domain.OrderStatusExpired, "hard_exit_deadline_cancel"); err != nil {
|
||||
return err
|
||||
}
|
||||
positionsList, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var failed []domain.Position
|
||||
now := s.nowUTC()
|
||||
for _, pos := range positionsList {
|
||||
switch pos.Status {
|
||||
case domain.PositionHoldingOvernight, domain.PositionExitPartiallyFilled, domain.PositionExitOrderSent:
|
||||
pos.Status = domain.PositionExitFailed
|
||||
pos.UpdatedAt = now
|
||||
if err := s.svc.Repo.UpsertPosition(ctx, pos); err != nil {
|
||||
return err
|
||||
}
|
||||
failed = append(failed, pos)
|
||||
_ = s.svc.Notifier.Alert(ctx, fmt.Sprintf("exit_failed: %s lots=%d", pos.InstrumentUID, pos.Lots))
|
||||
default:
|
||||
}
|
||||
}
|
||||
if len(failed) == 0 {
|
||||
return s.reconcileAndReport(ctx, s.nowUTC().In(s.cfg.Location))
|
||||
}
|
||||
return s.svc.Risk.Halt(ctx, s.cfg.Mode, "hard_exit_deadline_missed", fmt.Sprintf("%d positions remain open after hard deadline", len(failed)), "")
|
||||
}
|
||||
|
||||
func (s Scheduler) nowUTC() time.Time {
|
||||
if s.clock != nil {
|
||||
return s.clock.Now().UTC()
|
||||
}
|
||||
return time.Now().UTC()
|
||||
}
|
||||
|
||||
func repostAfter(now, deadline time.Time, attempts int, poll time.Duration) time.Duration {
|
||||
if attempts <= 1 {
|
||||
return 0
|
||||
}
|
||||
if poll <= 0 {
|
||||
poll = 500 * time.Millisecond
|
||||
}
|
||||
remaining := deadline.Sub(now)
|
||||
if remaining <= 0 {
|
||||
return poll
|
||||
}
|
||||
after := remaining / time.Duration(attempts)
|
||||
if after < poll {
|
||||
return poll
|
||||
}
|
||||
return after
|
||||
}
|
||||
|
||||
func (s Scheduler) transitionSequence(ctx context.Context, states ...domain.SystemState) error {
|
||||
for _, state := range states {
|
||||
if err := s.transitionTo(ctx, state); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s Scheduler) transitionTo(ctx context.Context, to domain.SystemState) error {
|
||||
from, halted, reason, err := s.svc.Repo.GetSystemState(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if halted || from == domain.StateHalted {
|
||||
return fmt.Errorf("%w: %s", statemachine.ErrSystemHalted, reason)
|
||||
}
|
||||
if from == to {
|
||||
return s.sm.Heartbeat(ctx, to)
|
||||
}
|
||||
if err := s.sm.Transition(ctx, from, to); err != nil {
|
||||
if errors.Is(err, statemachine.ErrIllegalTransition) {
|
||||
return s.sm.Heartbeat(ctx, to)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s Scheduler) halt(ctx context.Context, eventType, reason, instrumentUID string) error {
|
||||
_ = s.svc.Notifier.Alert(ctx, fmt.Sprintf("%s: %s", eventType, reason))
|
||||
return s.svc.Risk.Halt(ctx, s.cfg.Mode, eventType, reason, instrumentUID)
|
||||
}
|
||||
|
||||
func (s Scheduler) notifySignal(ctx context.Context, _ time.Time, sig domain.Signal) error {
|
||||
return s.svc.Notifier.Info(ctx, fmt.Sprintf("signal %s decision=%s edge=%s reason=%s lots=%d", sig.InstrumentUID, sig.Decision, sig.NetEdgeBps.StringFixed(2), sig.RejectReason, sig.TargetLots))
|
||||
}
|
||||
|
||||
func (s Scheduler) instrumentMap(ctx context.Context) (map[string]domain.Instrument, error) {
|
||||
instrumentsList, err := s.svc.Repo.ListInstruments(ctx, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make(map[string]domain.Instrument, len(instrumentsList))
|
||||
for _, instrument := range instrumentsList {
|
||||
out[instrument.InstrumentUID] = instrument
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s Scheduler) logWarn(msg string, args ...any) {
|
||||
if s.svc.Log != nil {
|
||||
s.svc.Log.Warn(msg, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func exitFillDelta(previous, current domain.Order) domain.Order {
|
||||
fill := current
|
||||
fill.FilledLots = current.FilledLots - previous.FilledLots
|
||||
if fill.FilledLots < 0 {
|
||||
fill.FilledLots = 0
|
||||
}
|
||||
fill.Commission = current.Commission.Sub(previous.Commission)
|
||||
if fill.Commission.IsNegative() {
|
||||
fill.Commission = decimal.Zero
|
||||
}
|
||||
if fill.FilledLots > 0 {
|
||||
currentValue := current.AvgFillPrice.Mul(decimal.NewFromInt(current.FilledLots))
|
||||
previousValue := previous.AvgFillPrice.Mul(decimal.NewFromInt(previous.FilledLots))
|
||||
fill.AvgFillPrice = currentValue.Sub(previousValue).Div(decimal.NewFromInt(fill.FilledLots))
|
||||
}
|
||||
return fill
|
||||
}
|
||||
|
||||
func spreadFromBook(book domain.OrderBook, tick decimal.Decimal) (features.SpreadResult, error) {
|
||||
bid, ask, err := bestBidAsk(book)
|
||||
if err != nil {
|
||||
return features.SpreadResult{}, err
|
||||
}
|
||||
return features.Spread(bid, ask, tick)
|
||||
}
|
||||
|
||||
func bestBidAsk(book domain.OrderBook) (decimal.Decimal, decimal.Decimal, error) {
|
||||
bid, ok := book.BestBid()
|
||||
if !ok {
|
||||
return decimal.Zero, decimal.Zero, execution.ErrEmptyOrderBook
|
||||
}
|
||||
ask, ok := book.BestAsk()
|
||||
if !ok {
|
||||
return decimal.Zero, decimal.Zero, execution.ErrEmptyOrderBook
|
||||
}
|
||||
return bid, ask, nil
|
||||
}
|
||||
|
||||
func hasOrder(orders []domain.Order, instrumentUID string, side domain.Side) bool {
|
||||
for _, order := range orders {
|
||||
if order.InstrumentUID == instrumentUID && order.Side == side && order.Status != domain.OrderStatusFailed && order.Status != domain.OrderStatusRejected {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func sinceMidnight(t time.Time) time.Duration {
|
||||
h, m, s := t.Clock()
|
||||
return time.Duration(h)*time.Hour + time.Duration(m)*time.Minute + time.Duration(s)*time.Second
|
||||
}
|
||||
|
||||
func tradingDate(t time.Time) time.Time {
|
||||
y, m, d := t.Date()
|
||||
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
@@ -0,0 +1,287 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/execution"
|
||||
"overnight-trading-bot/internal/reconciliation"
|
||||
"overnight-trading-bot/internal/risk"
|
||||
"overnight-trading-bot/internal/statemachine"
|
||||
"overnight-trading-bot/internal/testutil"
|
||||
"overnight-trading-bot/internal/timeutil"
|
||||
"overnight-trading-bot/internal/tinvest"
|
||||
)
|
||||
|
||||
func TestPhaseUsesMoscowWindows(t *testing.T) {
|
||||
loc := time.FixedZone("MSK", 3*60*60)
|
||||
s := Scheduler{cfg: Config{
|
||||
Location: loc,
|
||||
EntrySignalTime: mustTOD("18:10:00"),
|
||||
EntryWindowStart: mustTOD("18:20:00"),
|
||||
NoNewEntryAfter: mustTOD("18:38:30"),
|
||||
ExitWatchStart: mustTOD("09:50:00"),
|
||||
ExitWindowStart: mustTOD("10:05:00"),
|
||||
ExitWindowEnd: mustTOD("10:25:00"),
|
||||
HardExitDeadline: mustTOD("10:45:00"),
|
||||
}}
|
||||
tests := []struct {
|
||||
at string
|
||||
want domain.SystemState
|
||||
}{
|
||||
{"2026-06-06T09:55:00+03:00", domain.StateWaitExitWindow},
|
||||
{"2026-06-06T10:10:00+03:00", domain.StatePlaceExitOrders},
|
||||
{"2026-06-06T10:30:00+03:00", domain.StateMonitorExitOrders},
|
||||
{"2026-06-06T11:00:00+03:00", domain.StateReconcile},
|
||||
{"2026-06-06T18:15:00+03:00", domain.StateGenerateSignals},
|
||||
{"2026-06-06T18:25:00+03:00", domain.StatePlaceEntryOrders},
|
||||
{"2026-06-06T19:00:00+03:00", domain.StateHoldOvernight},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.at, func(t *testing.T) {
|
||||
at, err := time.Parse(time.RFC3339, tt.at)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := s.phase(at.In(loc)); got != tt.want {
|
||||
t.Fatalf("phase=%s, want %s", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInfrastructureOutageRequiresThreshold(t *testing.T) {
|
||||
gateway := tinvest.NewFakeGateway()
|
||||
gateway.ServerTime = time.Now().UTC().Add(-10 * time.Second)
|
||||
s := &Scheduler{
|
||||
cfg: Config{
|
||||
Mode: domain.ModeSandbox,
|
||||
MaxClockDrift: 2 * time.Second,
|
||||
APIOutageHalt: 180 * time.Second,
|
||||
},
|
||||
svc: Services{Gateway: gateway},
|
||||
}
|
||||
if err := s.checkInfrastructure(context.Background()); err != nil {
|
||||
t.Fatalf("first infrastructure failure should be tolerated: %v", err)
|
||||
}
|
||||
s.infraFailedSince = time.Now().UTC().Add(-181 * time.Second)
|
||||
if err := s.checkInfrastructure(context.Background()); err == nil {
|
||||
t.Fatalf("expected outage after threshold")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileAndReportIsIdempotentPerDate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
gateway := tinvest.NewFakeGateway()
|
||||
notifier := &countNotifier{}
|
||||
recon := reconciliation.New(repo, gateway, "account", "hash")
|
||||
s := Scheduler{
|
||||
cfg: Config{Mode: domain.ModePaper, Location: time.UTC},
|
||||
sm: statemachine.New(repo, domain.ModePaper),
|
||||
svc: Services{
|
||||
Repo: repo,
|
||||
Gateway: gateway,
|
||||
Reconcile: recon,
|
||||
Notifier: notifier,
|
||||
Risk: risk.NewManager(repo, risk.ManagerConfig{}),
|
||||
AccountID: "account",
|
||||
AccountIDHash: "hash",
|
||||
},
|
||||
}
|
||||
now := time.Date(2026, 6, 7, 12, 0, 0, 0, time.UTC)
|
||||
if err := s.reconcileAndReport(ctx, now); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := s.reconcileAndReport(ctx, now); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if notifier.reports != 1 {
|
||||
t.Fatalf("reports sent=%d, want 1", notifier.reports)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitFillDeltaUsesOnlyNewlyExecutedLots(t *testing.T) {
|
||||
previous := domain.Order{
|
||||
FilledLots: 2,
|
||||
AvgFillPrice: decimal.NewFromInt(100),
|
||||
Commission: decimal.NewFromFloat(0.50),
|
||||
}
|
||||
current := domain.Order{
|
||||
FilledLots: 4,
|
||||
AvgFillPrice: decimal.NewFromInt(110),
|
||||
Commission: decimal.NewFromFloat(1.25),
|
||||
}
|
||||
fill := exitFillDelta(previous, current)
|
||||
if fill.FilledLots != 2 {
|
||||
t.Fatalf("delta filled lots=%d, want 2", fill.FilledLots)
|
||||
}
|
||||
if !fill.AvgFillPrice.Equal(decimal.NewFromInt(120)) {
|
||||
t.Fatalf("delta avg fill price=%s, want 120", fill.AvgFillPrice)
|
||||
}
|
||||
if !fill.Commission.Equal(decimal.NewFromFloat(0.75)) {
|
||||
t.Fatalf("delta commission=%s, want 0.75", fill.Commission)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHardDeadlineMarksOpenPositionFailedAndHalts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
openDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
|
||||
if err := repo.UpsertPosition(ctx, domain.Position{
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid",
|
||||
OpenTradeDate: openDate,
|
||||
Lots: 1,
|
||||
Lot: 1,
|
||||
Status: domain.PositionHoldingOvernight,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
notifier := &countNotifier{}
|
||||
s := Scheduler{
|
||||
cfg: Config{Mode: domain.ModePaper, Location: time.UTC},
|
||||
svc: Services{
|
||||
Repo: repo,
|
||||
Risk: risk.NewManager(repo, risk.ManagerConfig{}),
|
||||
Notifier: notifier,
|
||||
AccountIDHash: "hash",
|
||||
},
|
||||
}
|
||||
if err := s.failOpenPositionsAtHardDeadline(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !repo.Halted || repo.State != domain.StateHalted {
|
||||
t.Fatalf("system not halted: state=%s halted=%v", repo.State, repo.Halted)
|
||||
}
|
||||
positions, err := repo.ListOpenPositions(ctx, "hash")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(positions) != 1 || positions[0].Status != domain.PositionExitFailed {
|
||||
t.Fatalf("positions=%+v, want EXIT_FAILED", positions)
|
||||
}
|
||||
if notifier.alerts != 1 {
|
||||
t.Fatalf("alerts=%d, want 1", notifier.alerts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHoldOvernightCancelsActiveBuyOrders(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
|
||||
if err := repo.UpsertOrder(ctx, domain.Order{
|
||||
ClientOrderID: "buy",
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid",
|
||||
TradeDate: tradeDate,
|
||||
Side: domain.SideBuy,
|
||||
OrderType: domain.OrderTypeLimit,
|
||||
QuantityLots: 1,
|
||||
Status: domain.OrderStatusNew,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := Scheduler{
|
||||
cfg: Config{Mode: domain.ModePaper, Location: time.UTC},
|
||||
sm: statemachine.New(repo, domain.ModePaper),
|
||||
svc: Services{
|
||||
Repo: repo,
|
||||
Execution: &execution.Engine{},
|
||||
AccountIDHash: "hash",
|
||||
},
|
||||
}
|
||||
if err := s.holdOvernight(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
orders, err := repo.ListOrders(ctx, "hash", tradeDate, tradeDate)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(orders) != 1 || orders[0].Status != domain.OrderStatusCancelled {
|
||||
t.Fatalf("orders=%+v, want CANCELLED", orders)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSizeReductionRuleCutsSizerAfterBadExpectedErrors(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
tradeDate := time.Date(2026, 6, 30, 0, 0, 0, 0, time.UTC)
|
||||
for i := 0; i < sizeReductionWindowTrades; i++ {
|
||||
date := tradeDate.AddDate(0, 0, -i)
|
||||
if err := repo.UpsertSignal(ctx, domain.Signal{
|
||||
TradeDate: date,
|
||||
InstrumentUID: "uid",
|
||||
Decision: domain.DecisionEnter,
|
||||
NetEdgeBps: decimal.NewFromInt(20),
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := repo.UpsertPosition(ctx, domain.Position{
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid",
|
||||
OpenTradeDate: date,
|
||||
Lot: 1,
|
||||
Status: domain.PositionExitFilled,
|
||||
RealizedEdgeBps: decimal.Zero,
|
||||
UpdatedAt: date.Add(time.Hour),
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
s := Scheduler{
|
||||
svc: Services{
|
||||
Repo: repo,
|
||||
AccountIDHash: "hash",
|
||||
Sizer: risk.NewSizer(risk.SizingConfig{
|
||||
MaxPositionPct: decimal.NewFromInt(1),
|
||||
MaxTotalExposurePct: decimal.NewFromInt(1),
|
||||
MaxParticipationRate: decimal.NewFromInt(1),
|
||||
CashUsageBuffer: decimal.NewFromInt(1),
|
||||
RiskBudgetPerInstrumentPct: decimal.NewFromInt(1),
|
||||
MinOrderNotionalRUB: decimal.NewFromInt(1),
|
||||
}),
|
||||
},
|
||||
}
|
||||
if err := s.applySizeReductionRule(ctx, tradeDate, true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sized := s.svc.Sizer.Size(risk.SizingInput{
|
||||
Portfolio: domain.Portfolio{Equity: decimal.NewFromInt(10_000), Cash: decimal.NewFromInt(10_000)},
|
||||
SelectedInstruments: 1,
|
||||
LimitPrice: decimal.NewFromInt(100),
|
||||
Lot: 1,
|
||||
EntryIntervalVolume: decimal.NewFromInt(10_000),
|
||||
ExitIntervalVolume: decimal.NewFromInt(10_000),
|
||||
Q05OvernightAbs: decimal.NewFromInt(1),
|
||||
})
|
||||
if sized.Lots != 50 {
|
||||
t.Fatalf("lots=%d, want reduced 50", sized.Lots)
|
||||
}
|
||||
if len(repo.RiskEvents) != 1 || repo.RiskEvents[0].EventType != "size_reduction_rule_triggered" {
|
||||
t.Fatalf("risk events=%+v", repo.RiskEvents)
|
||||
}
|
||||
}
|
||||
|
||||
func mustTOD(raw string) timeutil.TimeOfDay {
|
||||
tod, err := timeutil.ParseTimeOfDay(raw)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return tod
|
||||
}
|
||||
|
||||
type countNotifier struct {
|
||||
reports int
|
||||
alerts int
|
||||
}
|
||||
|
||||
func (n *countNotifier) Info(context.Context, string) error { return nil }
|
||||
func (n *countNotifier) Warn(context.Context, string) error { return nil }
|
||||
func (n *countNotifier) Alert(context.Context, string) error { n.alerts++; return nil }
|
||||
func (n *countNotifier) Report(context.Context, string) error { n.reports++; return nil }
|
||||
func (n *countNotifier) Close() error { return nil }
|
||||
@@ -0,0 +1,152 @@
|
||||
package signal
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
const (
|
||||
ReasonDisabled = "instrument_disabled"
|
||||
ReasonQuarantine = "instrument_quarantine"
|
||||
ReasonMetadataInvalid = "metadata_invalid"
|
||||
ReasonTradingStatus = "trading_status_not_normal"
|
||||
ReasonCommission = "commission_nonzero"
|
||||
ReasonMuShort = "mu_on_60_non_positive"
|
||||
ReasonMuLong = "mu_on_252_non_positive"
|
||||
ReasonSigmaZero = "sigma_on_60_zero"
|
||||
ReasonTStat = "tstat_on_60_below_threshold"
|
||||
ReasonWinRate = "win_on_60_below_threshold"
|
||||
ReasonNetEdge = "net_edge_bps_below_threshold"
|
||||
ReasonSpread = "spread_bps_above_limit"
|
||||
ReasonTick = "tick_bps_above_limit"
|
||||
ReasonADV = "adv_20_below_limit"
|
||||
ReasonFreeOrders = "free_order_budget_insufficient"
|
||||
ReasonMaxPositions = "max_positions_reached"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
MinTStat60 decimal.Decimal
|
||||
MinWinRate60 decimal.Decimal
|
||||
MinNetEdgeBps decimal.Decimal
|
||||
MinADVRUB decimal.Decimal
|
||||
MaxSpreadBpsDefault decimal.Decimal
|
||||
MaxSpreadBpsMoneyMarket decimal.Decimal
|
||||
MaxSpreadBpsBondFunds decimal.Decimal
|
||||
MaxSpreadBpsEquityFunds decimal.Decimal
|
||||
MaxTickBps decimal.Decimal
|
||||
RequireZeroCommission bool
|
||||
MaxPositions int
|
||||
}
|
||||
|
||||
type Candidate struct {
|
||||
Instrument domain.Instrument
|
||||
Features domain.FeatureSet
|
||||
TradingStatus domain.TradingStatus
|
||||
FreeOrderOK bool
|
||||
OpenPositions int
|
||||
TradeDate time.Time
|
||||
ExtraContext map[string]any
|
||||
}
|
||||
|
||||
type Engine struct {
|
||||
cfg Config
|
||||
}
|
||||
|
||||
func New(cfg Config) Engine {
|
||||
return Engine{cfg: cfg}
|
||||
}
|
||||
|
||||
func (e Engine) Evaluate(c Candidate) domain.Signal {
|
||||
reason := e.firstRejectReason(c)
|
||||
decision := domain.DecisionEnter
|
||||
if reason != "" {
|
||||
decision = domain.DecisionReject
|
||||
}
|
||||
if isSkipReason(reason) {
|
||||
decision = domain.DecisionSkip
|
||||
}
|
||||
context := map[string]any{
|
||||
"ticker": c.Instrument.Ticker,
|
||||
"fund_type": c.Instrument.FundType,
|
||||
"trading_status": c.TradingStatus,
|
||||
"spread_limit": e.spreadLimit(c.Instrument).String(),
|
||||
}
|
||||
for k, v := range c.ExtraContext {
|
||||
context[k] = v
|
||||
}
|
||||
raw, _ := json.Marshal(context)
|
||||
return domain.Signal{
|
||||
TradeDate: c.TradeDate,
|
||||
InstrumentUID: c.Instrument.InstrumentUID,
|
||||
Decision: decision,
|
||||
Score: c.Features.NetEdgeBps,
|
||||
NetEdgeBps: c.Features.NetEdgeBps,
|
||||
RejectReason: reason,
|
||||
ContextJSON: string(raw),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func isSkipReason(reason string) bool {
|
||||
return reason == ReasonFreeOrders || reason == ReasonMaxPositions
|
||||
}
|
||||
|
||||
func (e Engine) firstRejectReason(c Candidate) string {
|
||||
instr := c.Instrument
|
||||
features := c.Features
|
||||
switch {
|
||||
case !instr.Enabled:
|
||||
return ReasonDisabled
|
||||
case instr.Quarantine:
|
||||
return ReasonQuarantine
|
||||
case !instr.MetadataValid():
|
||||
return ReasonMetadataInvalid
|
||||
case c.TradingStatus != domain.TradingStatusNormal:
|
||||
return ReasonTradingStatus
|
||||
case e.cfg.RequireZeroCommission && instr.ExpectedCommissionBpsPerSide.IsPositive():
|
||||
return ReasonCommission
|
||||
case !features.MuOn60.IsPositive():
|
||||
return ReasonMuShort
|
||||
case !features.MuOn252.IsPositive():
|
||||
return ReasonMuLong
|
||||
case !features.SigmaOn60.IsPositive():
|
||||
return ReasonSigmaZero
|
||||
case features.TStatOn60.LessThan(e.cfg.MinTStat60):
|
||||
return ReasonTStat
|
||||
case features.WinOn60.LessThan(e.cfg.MinWinRate60):
|
||||
return ReasonWinRate
|
||||
case features.NetEdgeBps.LessThan(e.cfg.MinNetEdgeBps):
|
||||
return ReasonNetEdge
|
||||
case features.SpreadBps.GreaterThan(e.spreadLimit(instr)):
|
||||
return ReasonSpread
|
||||
case features.TickBps.GreaterThan(e.cfg.MaxTickBps):
|
||||
return ReasonTick
|
||||
case features.ADV20.LessThan(e.cfg.MinADVRUB):
|
||||
return ReasonADV
|
||||
case !c.FreeOrderOK:
|
||||
return ReasonFreeOrders
|
||||
case e.cfg.MaxPositions > 0 && c.OpenPositions >= e.cfg.MaxPositions:
|
||||
return ReasonMaxPositions
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (e Engine) spreadLimit(instr domain.Instrument) decimal.Decimal {
|
||||
fundType := strings.ToLower(instr.FundType)
|
||||
switch {
|
||||
case strings.Contains(fundType, "money"):
|
||||
return e.cfg.MaxSpreadBpsMoneyMarket
|
||||
case strings.Contains(fundType, "bond"):
|
||||
return e.cfg.MaxSpreadBpsBondFunds
|
||||
case strings.Contains(fundType, "equity"):
|
||||
return e.cfg.MaxSpreadBpsEquityFunds
|
||||
default:
|
||||
return e.cfg.MaxSpreadBpsDefault
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package signal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
func sd(raw string) decimal.Decimal {
|
||||
v, err := decimal.NewFromString(raw)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func baseCandidate() Candidate {
|
||||
return Candidate{
|
||||
TradeDate: time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC),
|
||||
Instrument: domain.Instrument{
|
||||
InstrumentUID: "uid",
|
||||
Ticker: "TRUR",
|
||||
ClassCode: "TQTF",
|
||||
Lot: 1,
|
||||
MinPriceIncrement: sd("0.01"),
|
||||
Currency: "RUB",
|
||||
Enabled: true,
|
||||
},
|
||||
Features: domain.FeatureSet{
|
||||
MuOn60: sd("0.002"),
|
||||
MuOn252: sd("0.001"),
|
||||
SigmaOn60: sd("0.01"),
|
||||
TStatOn60: sd("2"),
|
||||
WinOn60: sd("0.60"),
|
||||
NetEdgeBps: sd("20"),
|
||||
SpreadBps: sd("5"),
|
||||
TickBps: sd("1"),
|
||||
ADV20: sd("10000000"),
|
||||
},
|
||||
TradingStatus: domain.TradingStatusNormal,
|
||||
FreeOrderOK: true,
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineEnter(t *testing.T) {
|
||||
engine := New(Config{
|
||||
MinTStat60: sd("1.25"),
|
||||
MinWinRate60: sd("0.55"),
|
||||
MinNetEdgeBps: sd("10"),
|
||||
MinADVRUB: sd("5000000"),
|
||||
MaxSpreadBpsDefault: sd("20"),
|
||||
MaxSpreadBpsMoneyMarket: sd("5"),
|
||||
MaxSpreadBpsBondFunds: sd("10"),
|
||||
MaxSpreadBpsEquityFunds: sd("25"),
|
||||
MaxTickBps: sd("10"),
|
||||
RequireZeroCommission: true,
|
||||
MaxPositions: 5,
|
||||
})
|
||||
sig := engine.Evaluate(baseCandidate())
|
||||
if sig.Decision != domain.DecisionEnter || sig.RejectReason != "" {
|
||||
t.Fatalf("unexpected signal: %+v", sig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineFirstRejectReason(t *testing.T) {
|
||||
engine := New(Config{MinTStat60: sd("1.25"), MinWinRate60: sd("0.55"), MinNetEdgeBps: sd("10"), MinADVRUB: sd("5000000"), MaxSpreadBpsDefault: sd("20"), MaxTickBps: sd("10"), RequireZeroCommission: true})
|
||||
c := baseCandidate()
|
||||
c.Features.MuOn60 = decimal.Zero
|
||||
c.Features.NetEdgeBps = decimal.Zero
|
||||
sig := engine.Evaluate(c)
|
||||
if sig.RejectReason != ReasonMuShort {
|
||||
t.Fatalf("reason=%s", sig.RejectReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineUsesSkipForCapacityReasons(t *testing.T) {
|
||||
engine := New(Config{MinTStat60: sd("1.25"), MinWinRate60: sd("0.55"), MinNetEdgeBps: sd("10"), MinADVRUB: sd("5000000"), MaxSpreadBpsDefault: sd("20"), MaxTickBps: sd("10"), RequireZeroCommission: true, MaxPositions: 1})
|
||||
c := baseCandidate()
|
||||
c.OpenPositions = 1
|
||||
sig := engine.Evaluate(c)
|
||||
if sig.Decision != domain.DecisionSkip || sig.RejectReason != ReasonMaxPositions {
|
||||
t.Fatalf("unexpected skip signal: %+v", sig)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package statemachine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/reconciliation"
|
||||
"overnight-trading-bot/internal/repository"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrIllegalTransition = errors.New("illegal system transition")
|
||||
ErrSystemHalted = errors.New("system is halted")
|
||||
)
|
||||
|
||||
type System struct {
|
||||
repo repository.Repository
|
||||
mode domain.Mode
|
||||
}
|
||||
|
||||
func New(repo repository.Repository, mode domain.Mode) System {
|
||||
return System{repo: repo, mode: mode}
|
||||
}
|
||||
|
||||
func (s System) Recover(ctx context.Context, reconcile reconciliation.Engine) (domain.SystemState, error) {
|
||||
state, halted, reason, err := s.repo.GetSystemState(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if halted || state == domain.StateHalted {
|
||||
return domain.StateHalted, fmt.Errorf("system halted: %s", reason)
|
||||
}
|
||||
switch state {
|
||||
case domain.StatePlaceEntryOrders, domain.StateMonitorEntryOrders,
|
||||
domain.StatePlaceExitOrders, domain.StateMonitorExitOrders,
|
||||
domain.StateHoldOvernight:
|
||||
diffs, err := reconcile.Run(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if reconciliation.HasCritical(diffs) {
|
||||
if err := s.Halt(ctx, "critical reconciliation diff during recovery"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return domain.StateHalted, errors.New("critical reconciliation diff during recovery")
|
||||
}
|
||||
return state, nil
|
||||
case domain.StateInit, domain.StateSyncInstruments, domain.StateSyncMarketData, domain.StateGenerateSignals:
|
||||
return domain.StateInit, s.persist(ctx, domain.StateInit, false, "")
|
||||
default:
|
||||
return state, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s System) Transition(ctx context.Context, from, to domain.SystemState) error {
|
||||
current, halted, reason, err := s.repo.GetSystemState(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if (halted || current == domain.StateHalted) && to != domain.StateHalted {
|
||||
return fmt.Errorf("%w: %s", ErrSystemHalted, reason)
|
||||
}
|
||||
if !legalTransition(from, to) {
|
||||
return fmt.Errorf("%w: %s -> %s", ErrIllegalTransition, from, to)
|
||||
}
|
||||
return s.persist(ctx, to, false, "")
|
||||
}
|
||||
|
||||
func (s System) Halt(ctx context.Context, reason string) error {
|
||||
return s.persist(ctx, domain.StateHalted, true, reason)
|
||||
}
|
||||
|
||||
func (s System) Heartbeat(ctx context.Context, state domain.SystemState) error {
|
||||
current, halted, reason, err := s.repo.GetSystemState(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if halted || current == domain.StateHalted {
|
||||
return s.repo.SaveSystemState(ctx, domain.StateHalted, s.mode, true, reason, fmt.Sprintf(`{"heartbeat":"%s"}`, time.Now().UTC().Format(time.RFC3339Nano)))
|
||||
}
|
||||
return s.repo.SaveSystemState(ctx, state, s.mode, false, "", fmt.Sprintf(`{"heartbeat":"%s"}`, time.Now().UTC().Format(time.RFC3339Nano)))
|
||||
}
|
||||
|
||||
func (s System) persist(ctx context.Context, state domain.SystemState, halted bool, reason string) error {
|
||||
return s.repo.SaveSystemState(ctx, state, s.mode, halted, reason, "{}")
|
||||
}
|
||||
|
||||
func legalTransition(from, to domain.SystemState) bool {
|
||||
if from == to {
|
||||
return true
|
||||
}
|
||||
if to == domain.StateHalted {
|
||||
return true
|
||||
}
|
||||
allowed := map[domain.SystemState][]domain.SystemState{
|
||||
domain.StateInit: {domain.StateSyncInstruments, domain.StateWaitExitWindow},
|
||||
domain.StateSyncInstruments: {domain.StateSyncMarketData},
|
||||
domain.StateSyncMarketData: {domain.StateGenerateSignals},
|
||||
domain.StateGenerateSignals: {domain.StateWaitEntryWindow},
|
||||
domain.StateWaitEntryWindow: {domain.StatePlaceEntryOrders, domain.StateSleep},
|
||||
domain.StatePlaceEntryOrders: {domain.StateMonitorEntryOrders, domain.StateReconcile},
|
||||
domain.StateMonitorEntryOrders: {domain.StateHoldOvernight, domain.StateReconcile},
|
||||
domain.StateHoldOvernight: {domain.StateWaitExitWindow},
|
||||
domain.StateWaitExitWindow: {domain.StatePlaceExitOrders},
|
||||
domain.StatePlaceExitOrders: {domain.StateMonitorExitOrders, domain.StateReconcile},
|
||||
domain.StateMonitorExitOrders: {domain.StateReconcile},
|
||||
domain.StateReconcile: {domain.StateReport, domain.StateHalted},
|
||||
domain.StateReport: {domain.StateSleep},
|
||||
domain.StateSleep: {domain.StateInit, domain.StateWaitExitWindow, domain.StateGenerateSignals},
|
||||
}
|
||||
for _, candidate := range allowed[from] {
|
||||
if candidate == to {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package statemachine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/reconciliation"
|
||||
"overnight-trading-bot/internal/testutil"
|
||||
"overnight-trading-bot/internal/tinvest"
|
||||
)
|
||||
|
||||
func TestHeartbeatDoesNotClearHalt(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
system := New(repo, domain.ModeLiveTrade)
|
||||
if err := system.Halt(ctx, "manual kill switch"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := system.Heartbeat(ctx, domain.StateSleep); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
state, halted, reason, err := repo.GetSystemState(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if state != domain.StateHalted || !halted || reason != "manual kill switch" {
|
||||
t.Fatalf("halt was not sticky: state=%s halted=%v reason=%q", state, halted, reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransitionBlockedWhileHalted(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
system := New(repo, domain.ModePaper)
|
||||
if err := system.Halt(ctx, "risk"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err := system.Transition(ctx, domain.StateHalted, domain.StateInit)
|
||||
if !errors.Is(err, ErrSystemHalted) {
|
||||
t.Fatalf("expected ErrSystemHalted, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnhaltPreservesMode(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
if err := repo.SaveSystemState(ctx, domain.StateHalted, domain.ModeLiveTrade, true, "risk", "{}"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := repo.Unhalt(ctx, "checked"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, halted, _, err := repo.GetSystemState(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if halted || repo.Mode != domain.ModeLiveTrade {
|
||||
t.Fatalf("unhalt did not preserve mode: halted=%v mode=%s", halted, repo.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoverFromMonitorEntryHaltsOnCriticalReconciliationDiff(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
if err := repo.SaveSystemState(ctx, domain.StateMonitorEntryOrders, domain.ModePaper, false, "", "{}"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := repo.UpsertOrder(ctx, domain.Order{
|
||||
ClientOrderID: "local",
|
||||
BrokerOrderID: "broker-missing",
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid",
|
||||
TradeDate: time.Now().UTC(),
|
||||
Side: domain.SideBuy,
|
||||
OrderType: domain.OrderTypeLimit,
|
||||
QuantityLots: 1,
|
||||
Status: domain.OrderStatusSent,
|
||||
CreatedAt: time.Now().UTC().Add(-time.Minute),
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
system := New(repo, domain.ModePaper)
|
||||
state, err := system.Recover(ctx, reconciliation.New(repo, tinvest.NewFakeGateway(), "account", "hash"))
|
||||
if err == nil {
|
||||
t.Fatal("expected critical reconciliation error")
|
||||
}
|
||||
if state != domain.StateHalted || !repo.Halted {
|
||||
t.Fatalf("state=%s halted=%v, want HALTED", state, repo.Halted)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,344 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/repository"
|
||||
)
|
||||
|
||||
var _ repository.Repository = (*MemoryRepository)(nil)
|
||||
|
||||
type MemoryRepository struct {
|
||||
mu sync.Mutex
|
||||
|
||||
Instruments map[string]domain.Instrument
|
||||
Daily map[string][]domain.Candle
|
||||
Minute map[string][]domain.Candle
|
||||
Features map[string]domain.FeatureSet
|
||||
Signals map[string]domain.Signal
|
||||
Orders map[string]domain.Order
|
||||
Positions map[int64]domain.Position
|
||||
RiskEvents []domain.RiskEvent
|
||||
FreeOrders map[string]int
|
||||
Reports map[string]bool
|
||||
|
||||
State domain.SystemState
|
||||
Mode domain.Mode
|
||||
Halted bool
|
||||
HaltReason string
|
||||
|
||||
nextPositionID int64
|
||||
}
|
||||
|
||||
func NewMemoryRepository() *MemoryRepository {
|
||||
return &MemoryRepository{
|
||||
Instruments: make(map[string]domain.Instrument),
|
||||
Daily: make(map[string][]domain.Candle),
|
||||
Minute: make(map[string][]domain.Candle),
|
||||
Features: make(map[string]domain.FeatureSet),
|
||||
Signals: make(map[string]domain.Signal),
|
||||
Orders: make(map[string]domain.Order),
|
||||
Positions: make(map[int64]domain.Position),
|
||||
FreeOrders: make(map[string]int),
|
||||
Reports: make(map[string]bool),
|
||||
State: domain.StateInit,
|
||||
Mode: domain.ModePaper,
|
||||
nextPositionID: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) RunInTx(ctx context.Context, fn func(context.Context, repository.Repository) error) error {
|
||||
return fn(ctx, r)
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) UpsertInstrument(_ context.Context, instrument domain.Instrument) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.Instruments[instrument.InstrumentUID] = instrument
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) ReplaceInstrument(_ context.Context, oldInstrumentUID string, instrument domain.Instrument) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
delete(r.Instruments, oldInstrumentUID)
|
||||
r.Instruments[instrument.InstrumentUID] = instrument
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) ListInstruments(_ context.Context, includeDisabled bool) ([]domain.Instrument, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
out := make([]domain.Instrument, 0, len(r.Instruments))
|
||||
for _, instrument := range r.Instruments {
|
||||
if includeDisabled || instrument.Enabled {
|
||||
out = append(out, instrument)
|
||||
}
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool { return out[i].Ticker < out[j].Ticker })
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) QuarantineInstrument(_ context.Context, instrumentUID, reason string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
instrument := r.Instruments[instrumentUID]
|
||||
instrument.Quarantine = true
|
||||
instrument.QuarantineReason = reason
|
||||
r.Instruments[instrumentUID] = instrument
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) UpsertDailyCandles(_ context.Context, candles []domain.Candle) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, candle := range candles {
|
||||
r.Daily[candle.InstrumentUID] = upsertCandle(r.Daily[candle.InstrumentUID], candle, false)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) ListDailyCandles(_ context.Context, instrumentUID string, from, to time.Time) ([]domain.Candle, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return filterCandles(r.Daily[instrumentUID], from, to), nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) UpsertMinuteCandles(_ context.Context, candles []domain.Candle) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, candle := range candles {
|
||||
r.Minute[candle.InstrumentUID] = upsertCandle(r.Minute[candle.InstrumentUID], candle, true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) ListMinuteCandles(_ context.Context, instrumentUID string, from, to time.Time) ([]domain.Candle, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return filterCandles(r.Minute[instrumentUID], from, to), nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) UpsertFeature(_ context.Context, feature domain.FeatureSet) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.Features[featureKey(feature.InstrumentUID, feature.TradeDate)] = feature
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) GetFeature(_ context.Context, instrumentUID string, tradeDate time.Time) (domain.FeatureSet, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.Features[featureKey(instrumentUID, tradeDate)], nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) UpsertSignal(_ context.Context, signal domain.Signal) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.Signals[featureKey(signal.InstrumentUID, signal.TradeDate)] = signal
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) ListSignals(_ context.Context, tradeDate time.Time) ([]domain.Signal, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
var out []domain.Signal
|
||||
for _, signal := range r.Signals {
|
||||
if sameDate(signal.TradeDate, tradeDate) {
|
||||
out = append(out, signal)
|
||||
}
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool { return out[i].InstrumentUID < out[j].InstrumentUID })
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) UpsertOrder(_ context.Context, order domain.Order) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.Orders[order.ClientOrderID] = order
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) UpdateOrderStatus(_ context.Context, clientOrderID string, status domain.OrderStatus, filledLots int64, rawJSON string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
order := r.Orders[clientOrderID]
|
||||
order.Status = status
|
||||
order.FilledLots = filledLots
|
||||
order.RawStateJSON = rawJSON
|
||||
r.Orders[clientOrderID] = order
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) ListActiveOrders(_ context.Context, accountIDHash string) ([]domain.Order, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
var out []domain.Order
|
||||
for _, order := range r.Orders {
|
||||
if order.AccountIDHash == accountIDHash && (order.Status == domain.OrderStatusNew || order.Status == domain.OrderStatusSent || order.Status == domain.OrderStatusPartiallyFilled) {
|
||||
out = append(out, order)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) ListOrders(_ context.Context, accountIDHash string, from, to time.Time) ([]domain.Order, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
var out []domain.Order
|
||||
for _, order := range r.Orders {
|
||||
if order.AccountIDHash == accountIDHash && !order.TradeDate.Before(dateOnly(from)) && !order.TradeDate.After(dateOnly(to)) {
|
||||
out = append(out, order)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) UpsertPosition(_ context.Context, position domain.Position) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for id, existing := range r.Positions {
|
||||
if existing.AccountIDHash == position.AccountIDHash &&
|
||||
existing.InstrumentUID == position.InstrumentUID &&
|
||||
sameDate(existing.OpenTradeDate, position.OpenTradeDate) {
|
||||
position.ID = id
|
||||
r.Positions[id] = position
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if position.ID == 0 {
|
||||
position.ID = r.nextPositionID
|
||||
r.nextPositionID++
|
||||
}
|
||||
r.Positions[position.ID] = position
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) ListOpenPositions(_ context.Context, accountIDHash string) ([]domain.Position, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
var out []domain.Position
|
||||
for _, pos := range r.Positions {
|
||||
if pos.AccountIDHash == accountIDHash && pos.Status != domain.PositionNoPosition && pos.Status != domain.PositionExitFilled && pos.Status != domain.PositionQuarantine {
|
||||
out = append(out, pos)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) ListPositions(_ context.Context, accountIDHash string, from, to time.Time) ([]domain.Position, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
var out []domain.Position
|
||||
for _, pos := range r.Positions {
|
||||
if pos.AccountIDHash == accountIDHash && !pos.OpenTradeDate.Before(dateOnly(from)) && !pos.OpenTradeDate.After(dateOnly(to)) {
|
||||
out = append(out, pos)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) InsertRiskEvent(_ context.Context, event domain.RiskEvent) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.RiskEvents = append(r.RiskEvents, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) GetFreeOrdersSent(_ context.Context, tradeDate time.Time, instrumentUID string) (int, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.FreeOrders[featureKey(instrumentUID, tradeDate)], nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) IncrementFreeOrders(_ context.Context, tradeDate time.Time, instrumentUID string, delta int) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.FreeOrders[featureKey(instrumentUID, tradeDate)] += delta
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) GetSystemState(_ context.Context) (domain.SystemState, bool, string, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.State, r.Halted, r.HaltReason, nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) SaveSystemState(_ context.Context, state domain.SystemState, mode domain.Mode, halted bool, reason string, _ string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.State = state
|
||||
r.Mode = mode
|
||||
r.Halted = halted
|
||||
r.HaltReason = reason
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) Unhalt(_ context.Context, reason string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if !r.Halted && r.State != domain.StateHalted {
|
||||
return fmt.Errorf("system is not halted")
|
||||
}
|
||||
r.RiskEvents = append(r.RiskEvents, domain.RiskEvent{Severity: domain.SeverityInfo, EventType: "manual_unhalt", Message: reason})
|
||||
r.State = domain.StateInit
|
||||
r.Halted = false
|
||||
r.HaltReason = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) WasDailyReportSent(_ context.Context, reportDate time.Time, accountIDHash string) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.Reports[accountIDHash+"|"+dateOnly(reportDate).Format("2006-01-02")], nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) MarkDailyReportSent(_ context.Context, reportDate time.Time, accountIDHash string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.Reports[accountIDHash+"|"+dateOnly(reportDate).Format("2006-01-02")] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) InsertReconciliation(_ context.Context, _ time.Time, _ string, _ bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func upsertCandle(candles []domain.Candle, candle domain.Candle, minute bool) []domain.Candle {
|
||||
for i, existing := range candles {
|
||||
if (!minute && sameDate(existing.TradeDate, candle.TradeDate)) || (minute && existing.TradeDate.Equal(candle.TradeDate)) {
|
||||
candles[i] = candle
|
||||
return candles
|
||||
}
|
||||
}
|
||||
return append(candles, candle)
|
||||
}
|
||||
|
||||
func filterCandles(candles []domain.Candle, from, to time.Time) []domain.Candle {
|
||||
var out []domain.Candle
|
||||
for _, candle := range candles {
|
||||
if !candle.TradeDate.Before(from) && !candle.TradeDate.After(to) {
|
||||
out = append(out, candle)
|
||||
}
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool { return out[i].TradeDate.Before(out[j].TradeDate) })
|
||||
return out
|
||||
}
|
||||
|
||||
func featureKey(instrumentUID string, date time.Time) string {
|
||||
return instrumentUID + "|" + dateOnly(date).Format("2006-01-02")
|
||||
}
|
||||
|
||||
func sameDate(a, b time.Time) bool {
|
||||
return dateOnly(a).Equal(dateOnly(b))
|
||||
}
|
||||
|
||||
func dateOnly(t time.Time) time.Time {
|
||||
y, m, d := t.UTC().Date()
|
||||
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package timeutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Clock interface {
|
||||
Now() time.Time
|
||||
Sleep(ctxDone <-chan struct{}, d time.Duration) bool
|
||||
}
|
||||
|
||||
type RealClock struct {
|
||||
Loc *time.Location
|
||||
}
|
||||
|
||||
func (c RealClock) Now() time.Time {
|
||||
now := time.Now()
|
||||
if c.Loc != nil {
|
||||
return now.In(c.Loc)
|
||||
}
|
||||
return now
|
||||
}
|
||||
|
||||
func (c RealClock) Sleep(ctxDone <-chan struct{}, d time.Duration) bool {
|
||||
timer := time.NewTimer(d)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-timer.C:
|
||||
return true
|
||||
case <-ctxDone:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type TimeOfDay struct {
|
||||
Duration time.Duration
|
||||
}
|
||||
|
||||
func ParseTimeOfDay(raw string) (TimeOfDay, error) {
|
||||
parsed, err := time.Parse("15:04:05", raw)
|
||||
if err != nil {
|
||||
return TimeOfDay{}, fmt.Errorf("parse time of day %q: %w", raw, err)
|
||||
}
|
||||
return TimeOfDay{
|
||||
Duration: time.Duration(parsed.Hour())*time.Hour +
|
||||
time.Duration(parsed.Minute())*time.Minute +
|
||||
time.Duration(parsed.Second())*time.Second,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *TimeOfDay) UnmarshalText(text []byte) error {
|
||||
parsed, err := ParseTimeOfDay(string(text))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*t = parsed
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t TimeOfDay) String() string {
|
||||
total := int64(t.Duration.Seconds())
|
||||
h := total / 3600
|
||||
m := (total % 3600) / 60
|
||||
s := total % 60
|
||||
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
|
||||
}
|
||||
|
||||
func (t TimeOfDay) On(date time.Time, loc *time.Location) time.Time {
|
||||
local := date.In(loc)
|
||||
y, m, d := local.Date()
|
||||
midnight := time.Date(y, m, d, 0, 0, 0, 0, loc)
|
||||
return midnight.Add(t.Duration)
|
||||
}
|
||||
|
||||
type Window struct {
|
||||
Start TimeOfDay
|
||||
End TimeOfDay
|
||||
}
|
||||
|
||||
func (w Window) Contains(now time.Time, loc *time.Location) bool {
|
||||
start := w.Start.On(now, loc)
|
||||
end := w.End.On(now, loc)
|
||||
return !now.Before(start) && now.Before(end)
|
||||
}
|
||||
|
||||
func Drift(local, server time.Time) time.Duration {
|
||||
d := local.Sub(server)
|
||||
if d < 0 {
|
||||
return -d
|
||||
}
|
||||
return d
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
package tinvest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
type Gateway interface {
|
||||
GetInstrument(ctx context.Context, ticker, classCode string) (domain.Instrument, error)
|
||||
GetCandles(ctx context.Context, instrumentUID string, interval string, from, to time.Time) ([]domain.Candle, error)
|
||||
GetOrderBook(ctx context.Context, instrumentUID string, depth int32) (domain.OrderBook, error)
|
||||
GetTradingStatus(ctx context.Context, instrumentUID string) (domain.TradingStatus, error)
|
||||
PostLimitOrder(ctx context.Context, accountID, instrumentUID string, side domain.Side, lots int64, price decimal.Decimal, clientOrderID string) (domain.Order, error)
|
||||
CancelOrder(ctx context.Context, accountID, orderID string) error
|
||||
GetOrderState(ctx context.Context, accountID, orderID string) (domain.Order, error)
|
||||
GetActiveOrders(ctx context.Context, accountID string) ([]domain.Order, error)
|
||||
GetPortfolio(ctx context.Context, accountID string) (domain.Portfolio, error)
|
||||
GetOperations(ctx context.Context, accountID string, from, to time.Time) ([]domain.Operation, error)
|
||||
GetServerTime(ctx context.Context) (time.Time, error)
|
||||
}
|
||||
|
||||
type FakeGateway struct {
|
||||
mu sync.Mutex
|
||||
Instruments map[string]domain.Instrument
|
||||
Candles map[string][]domain.Candle
|
||||
OrderBooks map[string]domain.OrderBook
|
||||
Statuses map[string]domain.TradingStatus
|
||||
Orders map[string]domain.Order
|
||||
Portfolio domain.Portfolio
|
||||
Operations []domain.Operation
|
||||
ServerTime time.Time
|
||||
}
|
||||
|
||||
func NewFakeGateway() *FakeGateway {
|
||||
return &FakeGateway{
|
||||
Instruments: make(map[string]domain.Instrument),
|
||||
Candles: make(map[string][]domain.Candle),
|
||||
OrderBooks: make(map[string]domain.OrderBook),
|
||||
Statuses: make(map[string]domain.TradingStatus),
|
||||
Orders: make(map[string]domain.Order),
|
||||
Portfolio: domain.Portfolio{
|
||||
Equity: decimal.NewFromInt(100_000),
|
||||
Cash: decimal.NewFromInt(100_000),
|
||||
CheckedAt: time.Now().UTC(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (f *FakeGateway) GetInstrument(_ context.Context, ticker, classCode string) (domain.Instrument, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
for _, instrument := range f.Instruments {
|
||||
if instrument.Ticker == ticker && instrument.ClassCode == classCode {
|
||||
return instrument, nil
|
||||
}
|
||||
}
|
||||
return domain.Instrument{}, ErrNotFound
|
||||
}
|
||||
|
||||
func (f *FakeGateway) GetCandles(_ context.Context, instrumentUID string, _ string, from, to time.Time) ([]domain.Candle, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
var out []domain.Candle
|
||||
for _, candle := range f.Candles[instrumentUID] {
|
||||
if !candle.TradeDate.Before(from) && !candle.TradeDate.After(to) {
|
||||
out = append(out, candle)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (f *FakeGateway) GetOrderBook(_ context.Context, instrumentUID string, _ int32) (domain.OrderBook, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
book, ok := f.OrderBooks[instrumentUID]
|
||||
if !ok {
|
||||
return domain.OrderBook{}, ErrNotFound
|
||||
}
|
||||
return book, nil
|
||||
}
|
||||
|
||||
func (f *FakeGateway) GetTradingStatus(_ context.Context, instrumentUID string) (domain.TradingStatus, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
status, ok := f.Statuses[instrumentUID]
|
||||
if !ok {
|
||||
return domain.TradingStatusNormal, nil
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (f *FakeGateway) PostLimitOrder(_ context.Context, accountID, instrumentUID string, side domain.Side, lots int64, price decimal.Decimal, clientOrderID string) (domain.Order, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
order := domain.Order{
|
||||
ClientOrderID: clientOrderID,
|
||||
BrokerOrderID: "fake-" + clientOrderID,
|
||||
AccountIDHash: accountID,
|
||||
InstrumentUID: instrumentUID,
|
||||
Side: side,
|
||||
OrderType: domain.OrderTypeLimit,
|
||||
LimitPrice: price,
|
||||
QuantityLots: lots,
|
||||
Status: domain.OrderStatusSent,
|
||||
RawStateJSON: "{}",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
f.Orders[order.BrokerOrderID] = order
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func (f *FakeGateway) CancelOrder(_ context.Context, _ string, orderID string) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
order, ok := f.Orders[orderID]
|
||||
if !ok {
|
||||
return ErrNotFound
|
||||
}
|
||||
order.Status = domain.OrderStatusCancelled
|
||||
order.UpdatedAt = time.Now().UTC()
|
||||
f.Orders[orderID] = order
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *FakeGateway) GetOrderState(_ context.Context, _ string, orderID string) (domain.Order, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
order, ok := f.Orders[orderID]
|
||||
if !ok {
|
||||
return domain.Order{}, ErrNotFound
|
||||
}
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func (f *FakeGateway) GetActiveOrders(_ context.Context, _ string) ([]domain.Order, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
out := make([]domain.Order, 0)
|
||||
for _, order := range f.Orders {
|
||||
if order.Status == domain.OrderStatusSent || order.Status == domain.OrderStatusPartiallyFilled {
|
||||
out = append(out, order)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (f *FakeGateway) GetPortfolio(_ context.Context, _ string) (domain.Portfolio, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.Portfolio.CheckedAt = time.Now().UTC()
|
||||
return f.Portfolio, nil
|
||||
}
|
||||
|
||||
func (f *FakeGateway) GetOperations(_ context.Context, _ string, from, to time.Time) ([]domain.Operation, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
var out []domain.Operation
|
||||
for _, op := range f.Operations {
|
||||
if !op.ExecutedAt.Before(from) && !op.ExecutedAt.After(to) {
|
||||
out = append(out, op)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (f *FakeGateway) GetServerTime(context.Context) (time.Time, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if f.ServerTime.IsZero() {
|
||||
return time.Now().UTC(), nil
|
||||
}
|
||||
return f.ServerTime, nil
|
||||
}
|
||||
@@ -0,0 +1,456 @@
|
||||
package tinvest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/russianinvestments/invest-api-go-sdk/investgo"
|
||||
pb "github.com/russianinvestments/invest-api-go-sdk/proto"
|
||||
"github.com/shopspring/decimal"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/logging"
|
||||
"overnight-trading-bot/internal/money"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Token string
|
||||
AccountID string
|
||||
Endpoint string
|
||||
AppName string
|
||||
RetryCount int
|
||||
RetryBackoff time.Duration
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
type RealGateway struct {
|
||||
client *investgo.Client
|
||||
instruments *investgo.InstrumentsServiceClient
|
||||
marketData *investgo.MarketDataServiceClient
|
||||
orders *investgo.OrdersServiceClient
|
||||
operations *investgo.OperationsServiceClient
|
||||
users *investgo.UsersServiceClient
|
||||
retryAttempts int
|
||||
retryBackoff time.Duration
|
||||
}
|
||||
|
||||
func NewRealGateway(ctx context.Context, opts Options) (*RealGateway, error) {
|
||||
if opts.Token == "" {
|
||||
return nil, fmt.Errorf("tinvest token is required")
|
||||
}
|
||||
client, err := investgo.NewClient(ctx, investgo.Config{
|
||||
EndPoint: opts.Endpoint,
|
||||
Token: opts.Token,
|
||||
AppName: opts.AppName,
|
||||
AccountId: opts.AccountID,
|
||||
MaxRetries: 0,
|
||||
}, logging.SDKLogger{Logger: opts.Logger})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &RealGateway{
|
||||
client: client,
|
||||
instruments: client.NewInstrumentsServiceClient(),
|
||||
marketData: client.NewMarketDataServiceClient(),
|
||||
orders: client.NewOrdersServiceClient(),
|
||||
operations: client.NewOperationsServiceClient(),
|
||||
users: client.NewUsersServiceClient(),
|
||||
retryAttempts: opts.RetryCount,
|
||||
retryBackoff: opts.RetryBackoff,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (g *RealGateway) Close() error {
|
||||
if g.client == nil || g.client.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
return g.client.Conn.Close()
|
||||
}
|
||||
|
||||
func (g *RealGateway) GetInstrument(ctx context.Context, ticker, classCode string) (domain.Instrument, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return domain.Instrument{}, err
|
||||
}
|
||||
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.EtfResponse, error) {
|
||||
return g.instruments.EtfByTicker(ticker, classCode)
|
||||
})
|
||||
if err != nil {
|
||||
return domain.Instrument{}, err
|
||||
}
|
||||
etf := resp.GetInstrument()
|
||||
if etf == nil {
|
||||
return domain.Instrument{}, ErrNotFound
|
||||
}
|
||||
return domain.Instrument{
|
||||
InstrumentUID: etf.GetUid(),
|
||||
Figi: etf.GetFigi(),
|
||||
Ticker: etf.GetTicker(),
|
||||
ClassCode: etf.GetClassCode(),
|
||||
Name: etf.GetName(),
|
||||
Lot: int64(etf.GetLot()),
|
||||
MinPriceIncrement: money.QuotationToDecimal(etf.GetMinPriceIncrement()),
|
||||
Currency: strings.ToUpper(etf.GetCurrency()),
|
||||
Enabled: etf.GetApiTradeAvailableFlag() && etf.GetBuyAvailableFlag() && etf.GetSellAvailableFlag(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (g *RealGateway) GetCandles(ctx context.Context, instrumentUID string, interval string, from, to time.Time) ([]domain.Candle, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetCandlesResponse, error) {
|
||||
return g.marketData.GetCandles(instrumentUID, candleInterval(interval), from, to, pb.GetCandlesRequest_CANDLE_SOURCE_EXCHANGE, 0)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
candles := resp.GetCandles()
|
||||
out := make([]domain.Candle, 0, len(candles))
|
||||
for _, candle := range candles {
|
||||
out = append(out, domain.Candle{
|
||||
InstrumentUID: instrumentUID,
|
||||
TradeDate: candle.GetTime().AsTime().UTC(),
|
||||
Open: money.QuotationToDecimal(candle.GetOpen()),
|
||||
High: money.QuotationToDecimal(candle.GetHigh()),
|
||||
Low: money.QuotationToDecimal(candle.GetLow()),
|
||||
Close: money.QuotationToDecimal(candle.GetClose()),
|
||||
VolumeLots: decimal.NewFromInt(candle.GetVolume()),
|
||||
Source: "tinvest",
|
||||
LoadedAt: time.Now().UTC(),
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (g *RealGateway) GetOrderBook(ctx context.Context, instrumentUID string, depth int32) (domain.OrderBook, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return domain.OrderBook{}, err
|
||||
}
|
||||
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrderBookResponse, error) {
|
||||
return g.marketData.GetOrderBook(instrumentUID, depth)
|
||||
})
|
||||
if err != nil {
|
||||
return domain.OrderBook{}, err
|
||||
}
|
||||
return domain.OrderBook{
|
||||
InstrumentUID: instrumentUID,
|
||||
Bids: orderBookLevels(resp.GetBids()),
|
||||
Asks: orderBookLevels(resp.GetAsks()),
|
||||
Time: resp.GetOrderbookTs().AsTime().UTC(),
|
||||
ReceivedAt: time.Now().UTC(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (g *RealGateway) GetTradingStatus(ctx context.Context, instrumentUID string) (domain.TradingStatus, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return domain.TradingStatusUnknown, err
|
||||
}
|
||||
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetTradingStatusResponse, error) {
|
||||
return g.marketData.GetTradingStatus(instrumentUID)
|
||||
})
|
||||
if err != nil {
|
||||
return domain.TradingStatusUnknown, err
|
||||
}
|
||||
if resp.GetTradingStatus() == pb.SecurityTradingStatus_SECURITY_TRADING_STATUS_NORMAL_TRADING &&
|
||||
resp.GetLimitOrderAvailableFlag() &&
|
||||
resp.GetApiTradeAvailableFlag() {
|
||||
return domain.TradingStatusNormal, nil
|
||||
}
|
||||
return domain.TradingStatusClosed, nil
|
||||
}
|
||||
|
||||
func (g *RealGateway) PostLimitOrder(ctx context.Context, accountID, instrumentUID string, side domain.Side, lots int64, price decimal.Decimal, clientOrderID string) (domain.Order, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
direction := pb.OrderDirection_ORDER_DIRECTION_BUY
|
||||
if side == domain.SideSell {
|
||||
direction = pb.OrderDirection_ORDER_DIRECTION_SELL
|
||||
}
|
||||
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PostOrderResponse, error) {
|
||||
return g.orders.PostOrder(&investgo.PostOrderRequest{
|
||||
InstrumentId: instrumentUID,
|
||||
Quantity: lots,
|
||||
Price: money.DecimalToQuotation(price),
|
||||
Direction: direction,
|
||||
AccountId: accountID,
|
||||
OrderType: pb.OrderType_ORDER_TYPE_LIMIT,
|
||||
OrderId: clientOrderID,
|
||||
TimeInForce: pb.TimeInForceType_TIME_IN_FORCE_DAY,
|
||||
PriceType: pb.PriceType_PRICE_TYPE_CURRENCY,
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
return orderFromPostResponse(resp.PostOrderResponse, accountID, clientOrderID, side, price), nil
|
||||
}
|
||||
|
||||
func (g *RealGateway) CancelOrder(ctx context.Context, accountID, orderID string) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return withRetry(ctx, g.retryAttempts, g.retryBackoff, func() error {
|
||||
_, err := g.orders.CancelOrder(accountID, orderID, nil)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (g *RealGateway) GetOrderState(ctx context.Context, accountID, orderID string) (domain.Order, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrderStateResponse, error) {
|
||||
return g.orders.GetOrderState(accountID, orderID, pb.PriceType_PRICE_TYPE_CURRENCY, nil)
|
||||
})
|
||||
if err != nil {
|
||||
return domain.Order{}, err
|
||||
}
|
||||
return orderFromState(resp.OrderState, accountID), nil
|
||||
}
|
||||
|
||||
func (g *RealGateway) GetActiveOrders(ctx context.Context, accountID string) ([]domain.Order, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrdersResponse, error) {
|
||||
return g.orders.GetOrders(accountID, nil)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
states := resp.GetOrders()
|
||||
out := make([]domain.Order, 0, len(states))
|
||||
for _, state := range states {
|
||||
out = append(out, orderFromState(state, accountID))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (g *RealGateway) GetPortfolio(ctx context.Context, accountID string) (domain.Portfolio, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return domain.Portfolio{}, err
|
||||
}
|
||||
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PortfolioResponse, error) {
|
||||
return g.operations.GetPortfolio(accountID, pb.PortfolioRequest_RUB)
|
||||
})
|
||||
if err != nil {
|
||||
return domain.Portfolio{}, err
|
||||
}
|
||||
positions := resp.GetPositions()
|
||||
holdings := make([]domain.Holding, 0, len(positions))
|
||||
for _, position := range positions {
|
||||
holdings = append(holdings, domain.Holding{
|
||||
InstrumentUID: position.GetInstrumentUid(),
|
||||
QuantityLots: money.QuotationToDecimal(position.GetQuantity()).IntPart(),
|
||||
AveragePrice: money.MoneyValueToDecimal(position.GetAveragePositionPrice()),
|
||||
MarketValue: money.MoneyValueToDecimal(position.GetCurrentPrice()).Mul(money.QuotationToDecimal(position.GetQuantity())),
|
||||
})
|
||||
}
|
||||
equity, err := rubMoneyValueToDecimal(resp.GetTotalAmountPortfolio())
|
||||
if err != nil {
|
||||
return domain.Portfolio{}, err
|
||||
}
|
||||
cash, err := rubMoneyValueToDecimal(resp.GetTotalAmountCurrencies())
|
||||
if err != nil {
|
||||
return domain.Portfolio{}, err
|
||||
}
|
||||
return domain.Portfolio{
|
||||
Equity: equity,
|
||||
Cash: cash,
|
||||
Holdings: holdings,
|
||||
CheckedAt: time.Now().UTC(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (g *RealGateway) GetOperations(ctx context.Context, accountID string, from, to time.Time) ([]domain.Operation, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.OperationsResponse, error) {
|
||||
return g.operations.GetOperations(&investgo.GetOperationsRequest{
|
||||
AccountId: accountID,
|
||||
From: from,
|
||||
To: to,
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ops := resp.GetOperations()
|
||||
out := make([]domain.Operation, 0, len(ops))
|
||||
for _, op := range ops {
|
||||
payment := money.MoneyValueToDecimal(op.GetPayment())
|
||||
out = append(out, domain.Operation{
|
||||
ID: op.GetId(),
|
||||
InstrumentUID: op.GetInstrumentUid(),
|
||||
Type: op.GetOperationType().String(),
|
||||
Payment: payment,
|
||||
Commission: operationCommission(op.GetOperationType(), payment),
|
||||
ExecutedAt: op.GetDate().AsTime().UTC(),
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (g *RealGateway) GetServerTime(ctx context.Context) (time.Time, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetInfoResponse, error) {
|
||||
return g.users.GetInfo()
|
||||
})
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if serverTime, ok := serverTimeFromHeader(resp.Header); ok {
|
||||
return serverTime, nil
|
||||
}
|
||||
return time.Time{}, errors.New("server time is unavailable in response metadata")
|
||||
}
|
||||
|
||||
func operationCommission(operationType pb.OperationType, payment decimal.Decimal) decimal.Decimal {
|
||||
if operationType != pb.OperationType_OPERATION_TYPE_BROKER_FEE &&
|
||||
operationType != pb.OperationType_OPERATION_TYPE_SERVICE_FEE &&
|
||||
operationType != pb.OperationType_OPERATION_TYPE_SUCCESS_FEE {
|
||||
return decimal.Zero
|
||||
}
|
||||
return money.Abs(payment)
|
||||
}
|
||||
|
||||
func rubMoneyValueToDecimal(value *pb.MoneyValue) (decimal.Decimal, error) {
|
||||
if value == nil {
|
||||
return decimal.Zero, nil
|
||||
}
|
||||
if currency := strings.ToUpper(value.GetCurrency()); currency != "" && currency != "RUB" {
|
||||
return decimal.Zero, fmt.Errorf("expected RUB money value, got %s", currency)
|
||||
}
|
||||
return money.MoneyValueToDecimal(value), nil
|
||||
}
|
||||
|
||||
func serverTimeFromHeader(header map[string][]string) (time.Time, bool) {
|
||||
for _, key := range []string{"date", "Date"} {
|
||||
values := header[key]
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
parsed, err := http.ParseTime(values[0])
|
||||
if err == nil {
|
||||
return parsed.UTC(), true
|
||||
}
|
||||
}
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
func candleInterval(interval string) pb.CandleInterval {
|
||||
switch strings.ToLower(interval) {
|
||||
case "minute", "1m", "1min":
|
||||
return pb.CandleInterval_CANDLE_INTERVAL_1_MIN
|
||||
default:
|
||||
return pb.CandleInterval_CANDLE_INTERVAL_DAY
|
||||
}
|
||||
}
|
||||
|
||||
func orderBookLevels(levels []*pb.Order) []domain.OrderBookLevel {
|
||||
out := make([]domain.OrderBookLevel, 0, len(levels))
|
||||
for _, level := range levels {
|
||||
out = append(out, domain.OrderBookLevel{
|
||||
Price: money.QuotationToDecimal(level.GetPrice()),
|
||||
QuantityLots: level.GetQuantity(),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func orderFromPostResponse(resp *pb.PostOrderResponse, accountID, clientOrderID string, side domain.Side, limitPrice decimal.Decimal) domain.Order {
|
||||
if resp == nil {
|
||||
return domain.Order{}
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
return domain.Order{
|
||||
ClientOrderID: clientOrderID,
|
||||
BrokerOrderID: resp.GetOrderId(),
|
||||
AccountIDHash: accountID,
|
||||
InstrumentUID: resp.GetInstrumentUid(),
|
||||
Side: side,
|
||||
OrderType: domain.OrderTypeLimit,
|
||||
LimitPrice: limitPrice,
|
||||
QuantityLots: resp.GetLotsRequested(),
|
||||
FilledLots: resp.GetLotsExecuted(),
|
||||
AvgFillPrice: limitPrice,
|
||||
Status: mapOrderStatus(resp.GetExecutionReportStatus()),
|
||||
Commission: money.MoneyValueToDecimal(resp.GetExecutedCommission()),
|
||||
RawStateJSON: marshalProto(resp),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
func orderFromState(state *pb.OrderState, accountID string) domain.Order {
|
||||
if state == nil {
|
||||
return domain.Order{}
|
||||
}
|
||||
side := domain.SideBuy
|
||||
if state.GetDirection() == pb.OrderDirection_ORDER_DIRECTION_SELL {
|
||||
side = domain.SideSell
|
||||
}
|
||||
orderDate := time.Now().UTC()
|
||||
if state.GetOrderDate() != nil {
|
||||
orderDate = state.GetOrderDate().AsTime().UTC()
|
||||
}
|
||||
return domain.Order{
|
||||
ClientOrderID: state.GetOrderRequestId(),
|
||||
BrokerOrderID: state.GetOrderId(),
|
||||
AccountIDHash: accountID,
|
||||
InstrumentUID: state.GetInstrumentUid(),
|
||||
Side: side,
|
||||
OrderType: domain.OrderTypeLimit,
|
||||
LimitPrice: money.MoneyValueToDecimal(state.GetInitialSecurityPrice()),
|
||||
QuantityLots: state.GetLotsRequested(),
|
||||
FilledLots: state.GetLotsExecuted(),
|
||||
AvgFillPrice: money.MoneyValueToDecimal(state.GetAveragePositionPrice()),
|
||||
Status: mapOrderStatus(state.GetExecutionReportStatus()),
|
||||
Commission: money.MoneyValueToDecimal(state.GetExecutedCommission()),
|
||||
RawStateJSON: marshalProto(state),
|
||||
CreatedAt: orderDate,
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func mapOrderStatus(status pb.OrderExecutionReportStatus) domain.OrderStatus {
|
||||
switch status {
|
||||
case pb.OrderExecutionReportStatus_EXECUTION_REPORT_STATUS_FILL:
|
||||
return domain.OrderStatusFilled
|
||||
case pb.OrderExecutionReportStatus_EXECUTION_REPORT_STATUS_PARTIALLYFILL:
|
||||
return domain.OrderStatusPartiallyFilled
|
||||
case pb.OrderExecutionReportStatus_EXECUTION_REPORT_STATUS_CANCELLED:
|
||||
return domain.OrderStatusCancelled
|
||||
case pb.OrderExecutionReportStatus_EXECUTION_REPORT_STATUS_REJECTED:
|
||||
return domain.OrderStatusRejected
|
||||
case pb.OrderExecutionReportStatus_EXECUTION_REPORT_STATUS_NEW:
|
||||
return domain.OrderStatusSent
|
||||
default:
|
||||
return domain.OrderStatusNew
|
||||
}
|
||||
}
|
||||
|
||||
func marshalProto(msg proto.Message) string {
|
||||
if msg == nil {
|
||||
return "{}"
|
||||
}
|
||||
raw, err := protojson.Marshal(msg)
|
||||
if err != nil {
|
||||
fallback, _ := json.Marshal(map[string]string{"marshal_error": err.Error()})
|
||||
return string(fallback)
|
||||
}
|
||||
return string(raw)
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package tinvest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
backofflib "github.com/cenkalti/backoff/v4"
|
||||
)
|
||||
|
||||
func withRetry(ctx context.Context, attempts int, interval time.Duration, fn func() error) error {
|
||||
if attempts <= 0 {
|
||||
attempts = 1
|
||||
}
|
||||
if interval < 0 {
|
||||
interval = 0
|
||||
}
|
||||
policy := backofflib.NewExponentialBackOff()
|
||||
policy.InitialInterval = interval
|
||||
policy.MaxInterval = interval * 8
|
||||
policy.Multiplier = 2
|
||||
policy.MaxElapsedTime = 0
|
||||
policy.Reset()
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fn(); err != nil {
|
||||
lastErr = err
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
if attempt == attempts-1 || interval <= 0 {
|
||||
continue
|
||||
}
|
||||
timer := time.NewTimer(policy.NextBackOff())
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func retryValue[T any](ctx context.Context, attempts int, interval time.Duration, fn func() (T, error)) (T, error) {
|
||||
var out T
|
||||
err := withRetry(ctx, attempts, interval, func() error {
|
||||
var err error
|
||||
out, err = fn()
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
var zero T
|
||||
return zero, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package tinvest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWithRetryRetriesUntilSuccess(t *testing.T) {
|
||||
attempts := 0
|
||||
err := withRetry(context.Background(), 3, 0, func() error {
|
||||
attempts++
|
||||
if attempts < 3 {
|
||||
return errors.New("temporary")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if attempts != 3 {
|
||||
t.Fatalf("attempts=%d, want 3", attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithRetryStopsOnContextCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
attempts := 0
|
||||
err := withRetry(ctx, 3, time.Millisecond, func() error {
|
||||
attempts++
|
||||
return errors.New("temporary")
|
||||
})
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("err=%v, want context.Canceled", err)
|
||||
}
|
||||
if attempts != 0 {
|
||||
t.Fatalf("attempts=%d, want 0", attempts)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package tinvest
|
||||
|
||||
import "context"
|
||||
|
||||
const sandboxEndpoint = "sandbox-invest-public-api.tinkoff.ru:443"
|
||||
|
||||
func NewSandboxGateway(ctx context.Context, opts Options) (*RealGateway, error) {
|
||||
opts.Endpoint = sandboxEndpoint
|
||||
return NewRealGateway(ctx, opts)
|
||||
}
|
||||
Reference in New Issue
Block a user