diff --git a/README.md b/README.md index 32d0a5d..ef65b2b 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ APP_MODE=backtest go run ./cmd/bot | Переменная | Что указывать | Дефолт | Границы/валидация | За что отвечает и что меняется | | --- | --- | --- | --- | --- | -| `APP_MODE` | `backtest`, `paper`, `sandbox`, `live_readonly`, `live_trade` | нет, в `.env.example`: `paper` | обязательна; только перечисленные значения | Режим работы. `backtest` не требует БД и API в `cmd/bot`; `paper` использует fake gateway; `sandbox`, `live_readonly`, `live_trade` подключаются к T-Invest API; `live_trade` может отправлять брокерские заявки. | +| `APP_MODE` | `backtest`, `paper`, `sandbox`, `live_readonly`, `live_trade` | нет, в `.env.example`: `paper` | обязательна; только перечисленные значения | Режим работы. `backtest` не требует БД и API в `cmd/bot`; `paper` без `TINVEST_TOKEN` использует fake gateway, а с токеном берёт реальные market data/status через T-Invest при симулированных заявках; `sandbox`, `live_readonly`, `live_trade` подключаются к T-Invest API; `live_trade` может отправлять брокерские заявки. | | `APP_TIMEZONE` | `Europe/Moscow` | `Europe/Moscow` | жёстко только `Europe/Moscow` | Таймзона расписания торговых окон. Изменить нельзя без изменения валидации. | | `APP_LOG_LEVEL` | `debug`, `info`, `warn`, `warning`, `error` | `info` | неизвестное значение трактуется как `info` | Уровень JSON-логов. Ниже уровень - больше диагностических записей. | | `APP_HEALTHCHECK_ADDR` | HTTP listen address, например `:3300` или `127.0.0.1:3300` | `:3300` | без отдельной валидации | Адрес `/health` и `/ready`. При изменении меняется порт или интерфейс healthcheck-сервера. | @@ -39,11 +39,11 @@ APP_MODE=backtest go run ./cmd/bot | Переменная | Что указывать | Дефолт | Границы/валидация | За что отвечает и что меняется | | --- | --- | --- | --- | --- | -| `TINVEST_TOKEN` | токен T-Invest API | пусто | обязателен для `sandbox`, `live_readonly`, `live_trade` | Доступ к реальному или sandbox API. В `paper` и `backtest` не нужен. | +| `TINVEST_TOKEN` | токен T-Invest API | пусто | обязателен для `sandbox`, `live_readonly`, `live_trade`; опционален для `paper` | Доступ к реальному или sandbox API. В `paper` без токена используется fake gateway, с токеном - реальные market data и симулированные заявки. | | `TINVEST_ACCOUNT_ID` | идентификатор брокерского счёта | пусто | обязателен для `sandbox`, `live_readonly`, `live_trade` | Счёт для портфеля, заявок и сверки. Для API-режимов бот падает на старте, если account id не указан. | | `TINVEST_ENDPOINT` | gRPC endpoint T-Invest, обычно `host:port` | `invest-public-api.tinkoff.ru:443` | строка; валидации формата нет | Endpoint для API. В `sandbox` код принудительно использует sandbox endpoint. | | `TINVEST_APP_NAME` | имя приложения | `overnight-trading-bot` | строка | Передаётся в SDK как имя клиента. Меняет идентификацию приложения на стороне API/логов. | -| `TINVEST_REQUEST_TIMEOUT_SEC` | целое число секунд | `10` | рекомендуется `> 0`; сейчас не применяется | Зарезервировано под таймаут API-запросов. На текущий код не влияет. | +| `TINVEST_REQUEST_TIMEOUT_SEC` | целое число секунд | `10` | должно быть `> 0` | Таймаут API-запросов к T-Invest, включая retry-последовательность. Меньше значение быстрее освобождает торговый цикл при зависшем API, но повышает шанс timeout на медленной сети. | | `TINVEST_RETRY_COUNT` | целое число попыток | `3` | `<= 0` трактуется как одна попытка | Общее число попыток для SDK-вызовов T-Invest через exponential backoff. Больше значение повышает устойчивость к кратким сбоям, но может дольше задерживать окончательную ошибку. | | `TINVEST_RETRY_BACKOFF_SEC` | целое число секунд | `2` | рекомендуется `>= 0` | Начальный интервал exponential backoff для SDK-вызовов T-Invest. Больше значение снижает частоту повторов при сбоях, но дольше задерживает окончательную ошибку. | | `TINVEST_USE_SANDBOX` | `true` или `false` | `false` | boolean; разрешено только при `APP_MODE=sandbox` | Защитный флаг совместимости. В `live_readonly` и `live_trade` запрещён валидацией, чтобы случайно не подменить фактическую среду исполнения. | @@ -149,6 +149,8 @@ APP_MODE=backtest go run ./cmd/bot | `COMM_QUARANTINE_ON_NONZERO` | `true` или `false` | `true` | boolean | При фактической брокерской комиссии `> 0` инструмент переводится в quarantine, а система останавливается через HALT по zero-commission policy. | | `COMM_FREE_ORDER_COUNT_POLICY` | `submitted` или `cancel_counts` | `submitted` | одно из двух значений | Политика учёта бесплатных заявок: `submitted` считает только отправку новой заявки, `cancel_counts` дополнительно считает успешные отмены перед repost. | +В справочнике инструментов `free_order_limit_per_day=0` означает, что политика бесплатных заявок не настроена и новые входы запрещены; `-1` означает явно подтверждённое отсутствие дневного лимита. + ### BT | Переменная | Что указывать | Дефолт | Границы/валидация | За что отвечает и что меняется | @@ -181,6 +183,7 @@ go run ./cmd/migrate up go run ./cmd/backtest -candles candles.csv -out ./backtest_out go run ./cmd/backtest -candles candles.csv -minute-candles minute.csv -use-minute-model -out ./backtest_out go run ./cmd/bot -mode=paper +go run ./cmd/bot -halt -reason="manual kill switch" go run ./cmd/bot -unhalt -reason="manual reconciliation complete" go run ./cmd/bot -healthcheck ``` diff --git a/cmd/bot/main.go b/cmd/bot/main.go index 2c51d16..3ce0884 100644 --- a/cmd/bot/main.go +++ b/cmd/bot/main.go @@ -11,8 +11,9 @@ import ( func main() { mode := flag.String("mode", "", "override APP_MODE: backtest|paper|sandbox|live_readonly|live_trade") + halt := flag.Bool("halt", false, "manually set HALT and stop new automated actions") unhalt := flag.Bool("unhalt", false, "manually clear HALT after reconciliation") - reason := flag.String("reason", "", "audit reason for -unhalt") + reason := flag.String("reason", "", "audit reason for -halt or -unhalt") health := flag.Bool("healthcheck", false, "check local /health endpoint") healthURL := flag.String("healthcheck-url", "", "healthcheck URL; default http://127.0.0.1:3300/health") flag.Parse() @@ -21,6 +22,7 @@ func main() { Stdout: os.Stdout, Stderr: os.Stderr, ModeOverride: *mode, + Halt: *halt, Unhalt: *unhalt, Reason: *reason, Healthcheck: *health, diff --git a/internal/app/app.go b/internal/app/app.go index e1499b4..dcfeb78 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -43,6 +43,7 @@ type Options struct { Stdout io.Writer Stderr io.Writer ModeOverride string + Halt bool Unhalt bool Reason string Healthcheck bool @@ -84,10 +85,16 @@ func Run(ctx context.Context, opts Options) error { log := logging.New(cfg.App.LogLevel, opts.Stdout) log.Info("overnight trading bot starting", "mode", cfg.App.Mode) - if cfg.App.Mode == domain.ModeBacktest && !opts.Unhalt { + if opts.Halt && opts.Unhalt { + return errors.New("-halt and -unhalt are mutually exclusive") + } + if cfg.App.Mode == domain.ModeBacktest && !opts.Unhalt && !opts.Halt { _, _ = fmt.Fprintf(opts.Stdout, "overnight trading bot initialized in %s mode\n", cfg.App.Mode) return nil } + if opts.Halt && cfg.DB.DSN == "" { + return errors.New("-halt requires DB_DSN") + } db, err := openDB(ctx, cfg) if err != nil { @@ -102,6 +109,16 @@ func Run(ctx context.Context, opts Options) error { } } repo := mysqlrepo.NewRepository(db) + if opts.Halt { + if strings.TrimSpace(opts.Reason) == "" { + return errors.New("-halt requires -reason") + } + if err := risk.NewManager(repo, risk.ManagerConfig{}).Halt(ctx, cfg.App.Mode, "manual_halt", opts.Reason, ""); err != nil { + return err + } + _, _ = fmt.Fprintf(opts.Stdout, "system halted: %s\n", opts.Reason) + return nil + } if opts.Unhalt { if strings.TrimSpace(opts.Reason) == "" { return errors.New("-unhalt requires -reason") @@ -342,15 +359,36 @@ func openDB(ctx context.Context, cfg config.Config) (*sqlx.DB, error) { func buildGateway(ctx context.Context, cfg config.Config, log *slog.Logger) (tinvest.Gateway, func(), error) { switch cfg.App.Mode { case domain.ModePaper: + if cfg.TInvest.Token != "" { + accountID := cfg.TInvest.AccountID + if accountID == "" { + accountID = "paper-readonly" + } + gw, err := tinvest.NewRealGateway(ctx, tinvest.Options{ + Token: cfg.TInvest.Token, + AccountID: accountID, + Endpoint: cfg.TInvest.Endpoint, + AppName: cfg.TInvest.AppName, + RequestTimeout: time.Duration(cfg.TInvest.RequestTimeoutSec) * time.Second, + RetryCount: cfg.TInvest.RetryCount, + RetryBackoff: time.Duration(cfg.TInvest.RetryBackoffSec) * time.Second, + Logger: log, + }) + if err != nil { + return nil, nil, err + } + return tinvest.NewPaperGateway(gw), func() { _ = gw.Close() }, nil + } return tinvest.NewFakeGateway(), nil, nil case domain.ModeSandbox: gw, err := tinvest.NewSandboxGateway(ctx, tinvest.Options{ - Token: cfg.TInvest.Token, - AccountID: cfg.TInvest.AccountID, - AppName: cfg.TInvest.AppName, - RetryCount: cfg.TInvest.RetryCount, - RetryBackoff: time.Duration(cfg.TInvest.RetryBackoffSec) * time.Second, - Logger: log, + Token: cfg.TInvest.Token, + AccountID: cfg.TInvest.AccountID, + AppName: cfg.TInvest.AppName, + RequestTimeout: time.Duration(cfg.TInvest.RequestTimeoutSec) * time.Second, + RetryCount: cfg.TInvest.RetryCount, + RetryBackoff: time.Duration(cfg.TInvest.RetryBackoffSec) * time.Second, + Logger: log, }) if err != nil { return nil, nil, err @@ -362,13 +400,14 @@ func buildGateway(ctx context.Context, cfg config.Config, log *slog.Logger) (tin return nil, nil, errors.New("TINVEST_USE_SANDBOX is only allowed with APP_MODE=sandbox") } gw, err := tinvest.NewRealGateway(ctx, tinvest.Options{ - Token: cfg.TInvest.Token, - AccountID: cfg.TInvest.AccountID, - Endpoint: endpoint, - AppName: cfg.TInvest.AppName, - RetryCount: cfg.TInvest.RetryCount, - RetryBackoff: time.Duration(cfg.TInvest.RetryBackoffSec) * time.Second, - Logger: log, + Token: cfg.TInvest.Token, + AccountID: cfg.TInvest.AccountID, + Endpoint: endpoint, + AppName: cfg.TInvest.AppName, + RequestTimeout: time.Duration(cfg.TInvest.RequestTimeoutSec) * time.Second, + RetryCount: cfg.TInvest.RetryCount, + RetryBackoff: time.Duration(cfg.TInvest.RetryBackoffSec) * time.Second, + Logger: log, }) if err != nil { return nil, nil, err @@ -384,7 +423,11 @@ func seedPaperGateway(ctx context.Context, repo interface { }, gateway tinvest.Gateway) error { fake, ok := gateway.(*tinvest.FakeGateway) if !ok { - return nil + paper, isPaper := gateway.(*tinvest.PaperGateway) + if !isPaper { + return nil + } + fake = paper.Fake() } instrumentsList, err := repo.ListInstruments(ctx, true) if err != nil { diff --git a/internal/backtest/engine.go b/internal/backtest/engine.go index 4c16835..e757e41 100644 --- a/internal/backtest/engine.go +++ b/internal/backtest/engine.go @@ -448,6 +448,7 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle, rawEdge := decimal.NewFromFloat(short.Mean).Mul(decimal.NewFromInt(10_000)) spreadBps := e.assumedSpreadBps(instrumentUID) cost := spreadBps. + Add(spreadBps). Add(e.cfg.EntrySlippageBps). Add(e.cfg.ExitSlippageBps). Add(e.cfg.CommissionRoundtripBps). diff --git a/internal/config/config.go b/internal/config/config.go index 62679a7..3a5dd56 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -177,6 +177,9 @@ func (c *Config) Validate() error { if c.App.ShutdownTimeoutSec <= 0 { return errors.New("APP_SHUTDOWN_TIMEOUT_SEC must be positive") } + if c.TInvest.RequestTimeoutSec <= 0 { + return errors.New("TINVEST_REQUEST_TIMEOUT_SEC must be positive") + } if c.Execution.AllowMarketOrders { return errors.New("EXEC_ALLOW_MARKET_ORDERS must remain false: strategy is LIMIT-only") } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 051c6a7..043a9d8 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -35,8 +35,9 @@ func minimalBrokerConfig(mode domain.Mode) Config { ShutdownTimeoutSec: 30, }, TInvest: TInvestConfig{ - Token: "token", - AccountID: "account", + Token: "token", + AccountID: "account", + RequestTimeoutSec: 10, }, DB: DBConfig{DSN: "user:pass@tcp(localhost:3306)/bot"}, Execution: ExecutionConfig{ diff --git a/internal/execution/engine.go b/internal/execution/engine.go index 476e3e0..70054f8 100644 --- a/internal/execution/engine.go +++ b/internal/execution/engine.go @@ -507,9 +507,12 @@ func (e *Engine) repostDue(order domain.Order, after time.Duration) bool { } func (e *Engine) ensureRepostBudget(ctx context.Context, order domain.Order, instrument domain.Instrument) error { - if e.store == nil || instrument.FreeOrderLimitPerDay <= 0 { + if e.store == nil || instrument.FreeOrderLimitPerDay < 0 { return nil } + if instrument.FreeOrderLimitPerDay == 0 { + return risk.ErrFreeOrderPolicyUnspecified + } sent, err := e.store.GetFreeOrdersSent(ctx, order.TradeDate, instrument.InstrumentUID) if err != nil { return err diff --git a/internal/execution/state_test.go b/internal/execution/state_test.go index 3c18e0c..2eae243 100644 --- a/internal/execution/state_test.go +++ b/internal/execution/state_test.go @@ -249,9 +249,10 @@ func TestMonitorUntilRepostsAndExpiresAtDeadline(t *testing.T) { gateway := tinvest.NewFakeGateway() engine := NewEngine(domain.ModeSandbox, "account", gateway, repo) instrument := domain.Instrument{ - InstrumentUID: "uid", - Lot: 1, - MinPriceIncrement: decimal.NewFromInt(1), + InstrumentUID: "uid", + Lot: 1, + MinPriceIncrement: decimal.NewFromInt(1), + FreeOrderLimitPerDay: -1, } book := domain.OrderBook{ InstrumentUID: "uid", @@ -300,9 +301,10 @@ func TestMonitorOnceDoesNotRepostWhenCheckRejects(t *testing.T) { gateway := tinvest.NewFakeGateway() engine := NewEngine(domain.ModeSandbox, "account", gateway, repo) instrument := domain.Instrument{ - InstrumentUID: "uid", - Lot: 1, - MinPriceIncrement: decimal.NewFromInt(1), + InstrumentUID: "uid", + Lot: 1, + MinPriceIncrement: decimal.NewFromInt(1), + FreeOrderLimitPerDay: -1, } book := domain.OrderBook{ InstrumentUID: "uid", diff --git a/internal/features/pipeline.go b/internal/features/pipeline.go index a6442ca..b6d2553 100644 --- a/internal/features/pipeline.go +++ b/internal/features/pipeline.go @@ -114,6 +114,7 @@ func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate ti rawEdgeBps := decimal.NewFromFloat(short.Mean).Mul(decimal.NewFromInt(10_000)) commission := roundTripCommissionBps(instrument, cfg) expectedCost := spread.SpreadBps. + Add(spread.SpreadBps). Add(cfg.EntrySlippageBps). Add(cfg.ExitSlippageBps). Add(commission). diff --git a/internal/features/pipeline_test.go b/internal/features/pipeline_test.go index 178fadd..1b9d2bf 100644 --- a/internal/features/pipeline_test.go +++ b/internal/features/pipeline_test.go @@ -41,8 +41,8 @@ func TestComputeExpectedCostIncludesCommissionAndSlippage(t *testing.T) { if err != nil { t.Fatal(err) } - if !got.ExpectedCostBps.Equal(decimal.NewFromInt(22)) { - t.Fatalf("expected cost=%s, want 22", got.ExpectedCostBps) + if !got.ExpectedCostBps.Equal(decimal.NewFromInt(32)) { + t.Fatalf("expected cost=%s, want 32", got.ExpectedCostBps) } if !got.EntryIntervalVolume.Equal(decimal.NewFromInt(10000)) || !got.ExitIntervalVolume.Equal(decimal.NewFromInt(9000)) { t.Fatalf("interval volumes were not preserved: %+v", got) @@ -66,8 +66,8 @@ func TestComputeExpectedCostFallsBackToConfigCommission(t *testing.T) { if err != nil { t.Fatal(err) } - if !got.ExpectedCostBps.Equal(decimal.NewFromInt(24)) { - t.Fatalf("expected cost=%s, want 24", got.ExpectedCostBps) + if !got.ExpectedCostBps.Equal(decimal.NewFromInt(34)) { + t.Fatalf("expected cost=%s, want 34", got.ExpectedCostBps) } } diff --git a/internal/instruments/registry.go b/internal/instruments/registry.go index 7b4f1e4..304bf54 100644 --- a/internal/instruments/registry.go +++ b/internal/instruments/registry.go @@ -24,9 +24,12 @@ func (r Registry) SyncMetadata(ctx context.Context) error { return err } for _, instrument := range instruments { + if !instrument.Enabled || instrument.Quarantine { + continue + } remote, err := r.gateway.GetInstrument(ctx, instrument.Ticker, instrument.ClassCode) if err != nil { - return fmt.Errorf("sync %s: %w", instrument.Ticker, err) + continue } remote.Enabled = instrument.Enabled && remote.Enabled remote.FundType = instrument.FundType diff --git a/internal/marketdata/loader.go b/internal/marketdata/loader.go index c884292..f98fa48 100644 --- a/internal/marketdata/loader.go +++ b/internal/marketdata/loader.go @@ -28,33 +28,55 @@ func (l *Loader) SetClock(clock timeutil.Clock) { } func (l Loader) BackfillDaily(ctx context.Context, instruments []domain.Instrument, from, to time.Time) error { + eligible := 0 + succeeded := 0 + var firstErr error for _, instrument := range instruments { if !instrument.Enabled || instrument.Quarantine { continue } + eligible++ candles, err := l.gateway.GetCandles(ctx, instrument.InstrumentUID, "day", from, to) if err != nil { - return fmt.Errorf("load candles %s: %w", instrument.Ticker, err) + if firstErr == nil { + firstErr = fmt.Errorf("load candles %s: %w", instrument.Ticker, err) + } + continue } if err := l.repo.UpsertDailyCandles(ctx, candles); err != nil { return fmt.Errorf("persist candles %s: %w", instrument.Ticker, err) } + succeeded++ + } + if eligible > 0 && succeeded == 0 && firstErr != nil { + return fmt.Errorf("all daily candle loads failed: %w", firstErr) } return nil } func (l Loader) BackfillMinute(ctx context.Context, instruments []domain.Instrument, from, to time.Time) error { + eligible := 0 + succeeded := 0 + var firstErr error for _, instrument := range instruments { if !instrument.Enabled || instrument.Quarantine { continue } + eligible++ candles, err := l.gateway.GetCandles(ctx, instrument.InstrumentUID, "minute", from, to) if err != nil { - return fmt.Errorf("load minute candles %s: %w", instrument.Ticker, err) + if firstErr == nil { + firstErr = fmt.Errorf("load minute candles %s: %w", instrument.Ticker, err) + } + continue } if err := l.repo.UpsertMinuteCandles(ctx, candles); err != nil { return fmt.Errorf("persist minute candles %s: %w", instrument.Ticker, err) } + succeeded++ + } + if eligible > 0 && succeeded == 0 && firstErr != nil { + return fmt.Errorf("all minute candle loads failed: %w", firstErr) } return nil } diff --git a/internal/position/manager.go b/internal/position/manager.go index 822e33a..6408666 100644 --- a/internal/position/manager.go +++ b/internal/position/manager.go @@ -25,13 +25,48 @@ func (m Manager) OnEntryFill(ctx context.Context, accountIDHash string, instrume if lot <= 0 { lot = 1 } + fillLots := order.FilledLots + if fillLots < 0 { + fillLots = 0 + } + fillPrice := order.AvgFillPrice + if !fillPrice.IsPositive() { + fillPrice = order.LimitPrice + } + if existing, ok, err := m.findEntryPosition(ctx, accountIDHash, order); err != nil { + return domain.Position{}, err + } else if ok { + previousLots := existing.Lots + totalLots := previousLots + fillLots + if fillLots > 0 && totalLots > 0 { + previousValue := existing.AvgBuyPrice.Mul(decimal.NewFromInt(previousLots)) + fillValue := fillPrice.Mul(decimal.NewFromInt(fillLots)) + existing.AvgBuyPrice = previousValue.Add(fillValue).Div(decimal.NewFromInt(totalLots)) + } + existing.Lots = totalLots + existing.Lot = lot + existing.CommissionTotal = existing.CommissionTotal.Add(order.Commission) + if existing.OpenedAt == nil { + existing.OpenedAt = &now + } + if order.FilledLots < order.QuantityLots { + existing.Status = domain.PositionEntryPartiallyFilled + } else if existing.Status != domain.PositionHoldingOvernight { + existing.Status = domain.PositionEntryFilled + } + existing.UpdatedAt = now + if err := m.repo.UpsertPosition(ctx, existing); err != nil { + return domain.Position{}, err + } + return existing, nil + } pos := domain.Position{ AccountIDHash: accountIDHash, InstrumentUID: order.InstrumentUID, OpenTradeDate: order.TradeDate, - Lots: order.FilledLots, + Lots: fillLots, Lot: lot, - AvgBuyPrice: order.AvgFillPrice, + AvgBuyPrice: fillPrice, CommissionTotal: order.Commission, Status: domain.PositionEntryFilled, OpenedAt: &now, @@ -46,6 +81,28 @@ func (m Manager) OnEntryFill(ctx context.Context, accountIDHash string, instrume return pos, nil } +func (m Manager) findEntryPosition(ctx context.Context, accountIDHash string, order domain.Order) (domain.Position, bool, error) { + positions, err := m.repo.ListPositions(ctx, accountIDHash, order.TradeDate, order.TradeDate) + if err != nil { + return domain.Position{}, false, err + } + for _, pos := range positions { + if pos.InstrumentUID != order.InstrumentUID { + continue + } + switch pos.Status { + case domain.PositionEntrySignalled, + domain.PositionEntryOrderSent, + domain.PositionEntryPartiallyFilled, + domain.PositionEntryFilled, + domain.PositionHoldingOvernight: + return pos, true, nil + default: + } + } + return domain.Position{}, false, nil +} + func (m Manager) OnExitFill(ctx context.Context, pos domain.Position, exitOrder domain.Order) (domain.Position, error) { now := time.Now().UTC() lot := pos.Lot diff --git a/internal/position/manager_test.go b/internal/position/manager_test.go index 1a89130..7efc5f2 100644 --- a/internal/position/manager_test.go +++ b/internal/position/manager_test.go @@ -33,6 +33,47 @@ func TestOnEntryFillKeepsBuyCommission(t *testing.T) { } } +func TestOnEntryFillAggregatesRepostedPartialFills(t *testing.T) { + ctx := context.Background() + manager := NewManager(testutil.NewMemoryRepository()) + tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC) + first, err := manager.OnEntryFill(ctx, "hash", domain.Instrument{Lot: 1}, domain.Order{ + InstrumentUID: "uid", + TradeDate: tradeDate, + QuantityLots: 10, + FilledLots: 4, + AvgFillPrice: decimal.NewFromInt(100), + Commission: decimal.NewFromInt(1), + }) + if err != nil { + t.Fatal(err) + } + if first.Status != domain.PositionEntryPartiallyFilled || first.Lots != 4 { + t.Fatalf("first position=%+v, want partial 4 lots", first) + } + second, err := manager.OnEntryFill(ctx, "hash", domain.Instrument{Lot: 1}, domain.Order{ + InstrumentUID: "uid", + TradeDate: tradeDate, + QuantityLots: 6, + FilledLots: 6, + AvgFillPrice: decimal.NewFromInt(110), + Commission: decimal.NewFromInt(2), + }) + if err != nil { + t.Fatal(err) + } + wantAvg := decimal.NewFromInt(106) + if second.Lots != 10 || second.Status != domain.PositionEntryFilled { + t.Fatalf("aggregated position=%+v, want 10 lots ENTRY_FILLED", second) + } + if !second.AvgBuyPrice.Equal(wantAvg) { + t.Fatalf("avg buy=%s, want %s", second.AvgBuyPrice, wantAvg) + } + if !second.CommissionTotal.Equal(decimal.NewFromInt(3)) { + t.Fatalf("commission=%s, want 3", second.CommissionTotal) + } +} + func TestOnExitFillPartialUsesExecutedLots(t *testing.T) { ctx := context.Background() manager := NewManager(testutil.NewMemoryRepository()) diff --git a/internal/repository/migrations/0008_free_order_policy.down.sql b/internal/repository/migrations/0008_free_order_policy.down.sql new file mode 100644 index 0000000..e67c048 --- /dev/null +++ b/internal/repository/migrations/0008_free_order_policy.down.sql @@ -0,0 +1,4 @@ +ALTER TABLE instruments + MODIFY free_order_limit_per_day INT NOT NULL DEFAULT 0 COMMENT '0 means no configured free-order cap'; + +UPDATE schema_meta SET meta_value='0007' WHERE meta_key='schema_version'; diff --git a/internal/repository/migrations/0008_free_order_policy.up.sql b/internal/repository/migrations/0008_free_order_policy.up.sql new file mode 100644 index 0000000..cc606e0 --- /dev/null +++ b/internal/repository/migrations/0008_free_order_policy.up.sql @@ -0,0 +1,4 @@ +ALTER TABLE instruments + MODIFY free_order_limit_per_day INT NOT NULL DEFAULT 0 COMMENT '0 means free-order policy is unconfigured; -1 means explicitly no free-order cap'; + +UPDATE schema_meta SET meta_value='0008' WHERE meta_key='schema_version'; diff --git a/internal/risk/freeorders_test.go b/internal/risk/freeorders_test.go index c4f7204..732af6b 100644 --- a/internal/risk/freeorders_test.go +++ b/internal/risk/freeorders_test.go @@ -25,3 +25,15 @@ func TestFreeOrderBudgetSubmittedPolicy(t *testing.T) { t.Fatalf("expected ErrFreeOrderBudget, got %v", err) } } + +func TestFreeOrderBudgetRequiresExplicitPolicy(t *testing.T) { + ctx := context.Background() + budget := NewFreeOrderBudget(NewMemoryFreeOrderStore()) + date := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC) + if _, err := budget.Check(ctx, date, domain.Instrument{InstrumentUID: "uid"}, 1); !errors.Is(err, ErrFreeOrderPolicyUnspecified) { + t.Fatalf("expected ErrFreeOrderPolicyUnspecified, got %v", err) + } + if _, err := budget.Check(ctx, date, domain.Instrument{InstrumentUID: "uid", FreeOrderLimitPerDay: -1}, 1); err != nil { + t.Fatalf("explicit no-cap policy should pass, got %v", err) + } +} diff --git a/internal/risk/sizing.go b/internal/risk/sizing.go index 88e8479..c8068c1 100644 --- a/internal/risk/sizing.go +++ b/internal/risk/sizing.go @@ -13,8 +13,9 @@ import ( ) var ( - ErrNoSizingCapacity = errors.New("no sizing capacity") - ErrFreeOrderBudget = errors.New("free order budget is insufficient") + ErrNoSizingCapacity = errors.New("no sizing capacity") + ErrFreeOrderBudget = errors.New("free order budget is insufficient") + ErrFreeOrderPolicyUnspecified = errors.New("free order policy is not configured") ) type SizingConfig struct { @@ -131,9 +132,12 @@ func NewFreeOrderBudget(store FreeOrderStore) FreeOrderBudget { } func (b FreeOrderBudget) Check(ctx context.Context, tradeDate time.Time, instr domain.Instrument, ordersNeeded int) (int, error) { - if instr.FreeOrderLimitPerDay <= 0 { + if instr.FreeOrderLimitPerDay < 0 { return 0, nil } + if instr.FreeOrderLimitPerDay == 0 { + return 0, ErrFreeOrderPolicyUnspecified + } sent, err := b.store.GetFreeOrdersSent(ctx, tradeDate, instr.InstrumentUID) if err != nil { return 0, err diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index adb5b72..8072387 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -463,6 +463,12 @@ func (s *Scheduler) placeEntryOrders(ctx context.Context, now time.Time) error { if err != nil { tradingStatus = domain.TradingStatusUnknown } + if err := s.checkEntryInstrumentBeforeOrder(instrument, tradingStatus); err != nil { + if insertErr := s.recordPreTradeReject(ctx, sig.InstrumentUID, err.Error(), `{"reason":"instrument_pre_trade"}`); insertErr != nil { + return insertErr + } + continue + } portfolio, err = s.svc.Gateway.GetPortfolio(ctx, s.svc.AccountID) if err != nil { return err @@ -1153,6 +1159,12 @@ func (s Scheduler) repostPreTradeCheck(ctx context.Context, now time.Time, order if err != nil { tradingStatus = domain.TradingStatusUnknown } + if order.Side == domain.SideBuy { + if err := s.checkEntryInstrumentBeforeOrder(instrument, tradingStatus); err != nil { + _ = s.recordPreTradeReject(ctx, order.InstrumentUID, err.Error(), `{"reason":"instrument_pre_trade","stage":"repost"}`) + return err + } + } portfolio, err := s.svc.Gateway.GetPortfolio(ctx, s.svc.AccountID) if err != nil { return err @@ -1172,6 +1184,16 @@ func (s Scheduler) repostPreTradeCheck(ctx context.Context, now time.Time, order return nil } +func (s Scheduler) checkEntryInstrumentBeforeOrder(instrument domain.Instrument, tradingStatus domain.TradingStatus) error { + if err := instruments.CheckInstrument(instrument, tradingStatus); err != nil { + return err + } + if s.cfg.RequireZeroCommission && instrument.ExpectedCommissionBpsPerSide.IsPositive() { + return errors.New(signal.ReasonCommission) + } + return nil +} + func (s Scheduler) preTradeCheck(ctx context.Context, now time.Time, instrumentUID string, portfolio domain.Portfolio, openPositions int, tradingStatus domain.TradingStatus, quoteReceivedAt time.Time) (risk.PreTradeResult, error) { metrics, err := s.riskMetrics(ctx, now, portfolio) if err != nil { diff --git a/internal/scheduler/scheduler_test.go b/internal/scheduler/scheduler_test.go index 2aa1015..c857c0e 100644 --- a/internal/scheduler/scheduler_test.go +++ b/internal/scheduler/scheduler_test.go @@ -288,6 +288,34 @@ func TestNonZeroCommissionQuarantinesInstrumentAndHalts(t *testing.T) { } } +func TestEntryInstrumentPreTradeRejectsQuarantineAndCommission(t *testing.T) { + s := Scheduler{cfg: Config{RequireZeroCommission: true}} + err := s.checkEntryInstrumentBeforeOrder(domain.Instrument{ + InstrumentUID: "uid", + Ticker: "TRUR", + Enabled: true, + Quarantine: true, + Lot: 1, + MinPriceIncrement: decimal.NewFromInt(1), + Currency: "RUB", + }, domain.TradingStatusNormal) + if err == nil { + t.Fatal("expected quarantine rejection") + } + err = s.checkEntryInstrumentBeforeOrder(domain.Instrument{ + InstrumentUID: "uid", + Ticker: "TRUR", + Enabled: true, + Lot: 1, + MinPriceIncrement: decimal.NewFromInt(1), + Currency: "RUB", + ExpectedCommissionBpsPerSide: decimal.NewFromInt(1), + }, domain.TradingStatusNormal) + if err == nil || err.Error() != signalengine.ReasonCommission { + t.Fatalf("err=%v, want commission rejection", err) + } +} + func TestPreTradeDailyLossBreachHalts(t *testing.T) { ctx := context.Background() repo := testutil.NewMemoryRepository() @@ -481,13 +509,14 @@ func TestPlaceEntryRejectsWideSpreadBeforeOrder(t *testing.T) { repo := testutil.NewMemoryRepository() tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC) instrument := domain.Instrument{ - InstrumentUID: "uid", - Ticker: "TRUR", - ClassCode: "TQTF", - Enabled: true, - Lot: 1, - MinPriceIncrement: decimal.RequireFromString("0.01"), - Currency: "RUB", + InstrumentUID: "uid", + Ticker: "TRUR", + ClassCode: "TQTF", + Enabled: true, + Lot: 1, + MinPriceIncrement: decimal.RequireFromString("0.01"), + Currency: "RUB", + FreeOrderLimitPerDay: -1, } if err := repo.UpsertInstrument(ctx, instrument); err != nil { t.Fatal(err) diff --git a/internal/tinvest/paper.go b/internal/tinvest/paper.go new file mode 100644 index 0000000..bdff4fc --- /dev/null +++ b/internal/tinvest/paper.go @@ -0,0 +1,85 @@ +package tinvest + +import ( + "context" + "time" + + "github.com/shopspring/decimal" + + "overnight-trading-bot/internal/domain" +) + +type PaperGateway struct { + market Gateway + fake *FakeGateway +} + +func NewPaperGateway(market Gateway) *PaperGateway { + return &PaperGateway{market: market, fake: NewFakeGateway()} +} + +func (g *PaperGateway) Fake() *FakeGateway { + if g.fake == nil { + g.fake = NewFakeGateway() + } + return g.fake +} + +func (g *PaperGateway) GetInstrument(ctx context.Context, ticker, classCode string) (domain.Instrument, error) { + if g.market != nil { + return g.market.GetInstrument(ctx, ticker, classCode) + } + return g.Fake().GetInstrument(ctx, ticker, classCode) +} + +func (g *PaperGateway) GetCandles(ctx context.Context, instrumentUID string, interval string, from, to time.Time) ([]domain.Candle, error) { + if g.market != nil { + return g.market.GetCandles(ctx, instrumentUID, interval, from, to) + } + return g.Fake().GetCandles(ctx, instrumentUID, interval, from, to) +} + +func (g *PaperGateway) GetOrderBook(ctx context.Context, instrumentUID string, depth int32) (domain.OrderBook, error) { + if g.market != nil { + return g.market.GetOrderBook(ctx, instrumentUID, depth) + } + return g.Fake().GetOrderBook(ctx, instrumentUID, depth) +} + +func (g *PaperGateway) GetTradingStatus(ctx context.Context, instrumentUID string) (domain.TradingStatus, error) { + if g.market != nil { + return g.market.GetTradingStatus(ctx, instrumentUID) + } + return g.Fake().GetTradingStatus(ctx, instrumentUID) +} + +func (g *PaperGateway) PostLimitOrder(ctx context.Context, accountID, instrumentUID string, side domain.Side, lots int64, price decimal.Decimal, clientOrderID string) (domain.Order, error) { + return g.Fake().PostLimitOrder(ctx, accountID, instrumentUID, side, lots, price, clientOrderID) +} + +func (g *PaperGateway) CancelOrder(ctx context.Context, accountID, orderID string) error { + return g.Fake().CancelOrder(ctx, accountID, orderID) +} + +func (g *PaperGateway) GetOrderState(ctx context.Context, accountID, orderID string) (domain.Order, error) { + return g.Fake().GetOrderState(ctx, accountID, orderID) +} + +func (g *PaperGateway) GetActiveOrders(ctx context.Context, accountID string) ([]domain.Order, error) { + return g.Fake().GetActiveOrders(ctx, accountID) +} + +func (g *PaperGateway) GetPortfolio(ctx context.Context, accountID string) (domain.Portfolio, error) { + return g.Fake().GetPortfolio(ctx, accountID) +} + +func (g *PaperGateway) GetOperations(ctx context.Context, accountID string, from, to time.Time) ([]domain.Operation, error) { + return g.Fake().GetOperations(ctx, accountID, from, to) +} + +func (g *PaperGateway) GetServerTime(ctx context.Context) (time.Time, error) { + if g.market != nil { + return g.market.GetServerTime(ctx) + } + return g.Fake().GetServerTime(ctx) +} diff --git a/internal/tinvest/real.go b/internal/tinvest/real.go index 14f0fc6..cf90980 100644 --- a/internal/tinvest/real.go +++ b/internal/tinvest/real.go @@ -23,24 +23,26 @@ import ( ) type Options struct { - Token string - AccountID string - Endpoint string - AppName string - RetryCount int - RetryBackoff time.Duration - Logger *slog.Logger + Token string + AccountID string + Endpoint string + AppName string + RequestTimeout time.Duration + RetryCount int + RetryBackoff time.Duration + Logger *slog.Logger } type RealGateway struct { - client *investgo.Client - instruments *investgo.InstrumentsServiceClient - marketData *investgo.MarketDataServiceClient - orders *investgo.OrdersServiceClient - operations *investgo.OperationsServiceClient - users *investgo.UsersServiceClient - retryAttempts int - retryBackoff time.Duration + client *investgo.Client + instruments *investgo.InstrumentsServiceClient + marketData *investgo.MarketDataServiceClient + orders *investgo.OrdersServiceClient + operations *investgo.OperationsServiceClient + users *investgo.UsersServiceClient + requestTimeout time.Duration + retryAttempts int + retryBackoff time.Duration } func NewRealGateway(ctx context.Context, opts Options) (*RealGateway, error) { @@ -58,14 +60,15 @@ func NewRealGateway(ctx context.Context, opts Options) (*RealGateway, error) { return nil, err } return &RealGateway{ - client: client, - instruments: client.NewInstrumentsServiceClient(), - marketData: client.NewMarketDataServiceClient(), - orders: client.NewOrdersServiceClient(), - operations: client.NewOperationsServiceClient(), - users: client.NewUsersServiceClient(), - retryAttempts: opts.RetryCount, - retryBackoff: opts.RetryBackoff, + client: client, + instruments: client.NewInstrumentsServiceClient(), + marketData: client.NewMarketDataServiceClient(), + orders: client.NewOrdersServiceClient(), + operations: client.NewOperationsServiceClient(), + users: client.NewUsersServiceClient(), + requestTimeout: opts.RequestTimeout, + retryAttempts: opts.RetryCount, + retryBackoff: opts.RetryBackoff, }, nil } @@ -80,8 +83,10 @@ func (g *RealGateway) GetInstrument(ctx context.Context, ticker, classCode strin if err := ctx.Err(); err != nil { return domain.Instrument{}, err } - resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.EtfResponse, error) { - return g.instruments.EtfByTicker(ticker, classCode) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.EtfResponse, error) { + return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.EtfResponse, error) { + return g.instruments.EtfByTicker(ticker, classCode) + }) }) if err != nil { return domain.Instrument{}, err @@ -108,8 +113,10 @@ func (g *RealGateway) GetCandles(ctx context.Context, instrumentUID string, inte if err := ctx.Err(); err != nil { return nil, err } - resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetCandlesResponse, error) { - return g.marketData.GetCandles(instrumentUID, candleInterval(interval), from, to, pb.GetCandlesRequest_CANDLE_SOURCE_EXCHANGE, 0) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.GetCandlesResponse, error) { + return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetCandlesResponse, error) { + return g.marketData.GetCandles(instrumentUID, candleInterval(interval), from, to, pb.GetCandlesRequest_CANDLE_SOURCE_EXCHANGE, 0) + }) }) if err != nil { return nil, err @@ -136,8 +143,10 @@ func (g *RealGateway) GetOrderBook(ctx context.Context, instrumentUID string, de if err := ctx.Err(); err != nil { return domain.OrderBook{}, err } - resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrderBookResponse, error) { - return g.marketData.GetOrderBook(instrumentUID, depth) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.GetOrderBookResponse, error) { + return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrderBookResponse, error) { + return g.marketData.GetOrderBook(instrumentUID, depth) + }) }) if err != nil { return domain.OrderBook{}, err @@ -155,8 +164,10 @@ func (g *RealGateway) GetTradingStatus(ctx context.Context, instrumentUID string if err := ctx.Err(); err != nil { return domain.TradingStatusUnknown, err } - resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetTradingStatusResponse, error) { - return g.marketData.GetTradingStatus(instrumentUID) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.GetTradingStatusResponse, error) { + return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetTradingStatusResponse, error) { + return g.marketData.GetTradingStatus(instrumentUID) + }) }) if err != nil { return domain.TradingStatusUnknown, err @@ -181,17 +192,19 @@ func (g *RealGateway) PostLimitOrder(ctx context.Context, accountID, instrumentU if err != nil { return domain.Order{}, err } - resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PostOrderResponse, error) { - return g.orders.PostOrder(&investgo.PostOrderRequest{ - InstrumentId: instrumentUID, - Quantity: lots, - Price: quotation, - Direction: direction, - AccountId: accountID, - OrderType: pb.OrderType_ORDER_TYPE_LIMIT, - OrderId: clientOrderID, - TimeInForce: pb.TimeInForceType_TIME_IN_FORCE_DAY, - PriceType: pb.PriceType_PRICE_TYPE_CURRENCY, + resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.PostOrderResponse, error) { + return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PostOrderResponse, error) { + return g.orders.PostOrder(&investgo.PostOrderRequest{ + InstrumentId: instrumentUID, + Quantity: lots, + Price: quotation, + Direction: direction, + AccountId: accountID, + OrderType: pb.OrderType_ORDER_TYPE_LIMIT, + OrderId: clientOrderID, + TimeInForce: pb.TimeInForceType_TIME_IN_FORCE_DAY, + PriceType: pb.PriceType_PRICE_TYPE_CURRENCY, + }) }) }) if err != nil { @@ -204,18 +217,23 @@ func (g *RealGateway) CancelOrder(ctx context.Context, accountID, orderID string if err := ctx.Err(); err != nil { return err } - return withRetry(ctx, g.retryAttempts, g.retryBackoff, func() error { - _, err := g.orders.CancelOrder(accountID, orderID, nil) - return err + _, err := requestWithTimeout(ctx, g.requestTimeout, func() (struct{}, error) { + return struct{}{}, withRetry(ctx, g.retryAttempts, g.retryBackoff, func() error { + _, err := g.orders.CancelOrder(accountID, orderID, nil) + return err + }) }) + return err } func (g *RealGateway) GetOrderState(ctx context.Context, accountID, orderID string) (domain.Order, error) { if err := ctx.Err(); err != nil { return domain.Order{}, err } - resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrderStateResponse, error) { - return g.orders.GetOrderState(accountID, orderID, pb.PriceType_PRICE_TYPE_CURRENCY, nil) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.GetOrderStateResponse, error) { + return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrderStateResponse, error) { + return g.orders.GetOrderState(accountID, orderID, pb.PriceType_PRICE_TYPE_CURRENCY, nil) + }) }) if err != nil { return domain.Order{}, err @@ -227,8 +245,10 @@ func (g *RealGateway) GetActiveOrders(ctx context.Context, accountID string) ([] if err := ctx.Err(); err != nil { return nil, err } - resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrdersResponse, error) { - return g.orders.GetOrders(accountID, nil) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.GetOrdersResponse, error) { + return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrdersResponse, error) { + return g.orders.GetOrders(accountID, nil) + }) }) if err != nil { return nil, err @@ -245,8 +265,10 @@ func (g *RealGateway) GetPortfolio(ctx context.Context, accountID string) (domai if err := ctx.Err(); err != nil { return domain.Portfolio{}, err } - resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PortfolioResponse, error) { - return g.operations.GetPortfolio(accountID, pb.PortfolioRequest_RUB) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.PortfolioResponse, error) { + return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PortfolioResponse, error) { + return g.operations.GetPortfolio(accountID, pb.PortfolioRequest_RUB) + }) }) if err != nil { return domain.Portfolio{}, err @@ -258,11 +280,13 @@ func (g *RealGateway) GetOperations(ctx context.Context, accountID string, from, if err := ctx.Err(); err != nil { return nil, err } - resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.OperationsResponse, error) { - return g.operations.GetOperations(&investgo.GetOperationsRequest{ - AccountId: accountID, - From: from, - To: to, + resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.OperationsResponse, error) { + return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.OperationsResponse, error) { + return g.operations.GetOperations(&investgo.GetOperationsRequest{ + AccountId: accountID, + From: from, + To: to, + }) }) }) if err != nil { @@ -319,8 +343,10 @@ func (g *RealGateway) GetServerTime(ctx context.Context) (time.Time, error) { if err := ctx.Err(); err != nil { return time.Time{}, err } - resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetInfoResponse, error) { - return g.users.GetInfo() + resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.GetInfoResponse, error) { + return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetInfoResponse, error) { + return g.users.GetInfo() + }) }) if err != nil { return time.Time{}, err diff --git a/internal/tinvest/retry.go b/internal/tinvest/retry.go index 4bf2147..5337966 100644 --- a/internal/tinvest/retry.go +++ b/internal/tinvest/retry.go @@ -62,3 +62,27 @@ func retryValue[T any](ctx context.Context, attempts int, interval time.Duration } return out, nil } + +func requestWithTimeout[T any](ctx context.Context, timeout time.Duration, fn func() (T, error)) (T, error) { + if timeout <= 0 { + return fn() + } + callCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + type result struct { + value T + err error + } + done := make(chan result, 1) + go func() { + value, err := fn() + done <- result{value: value, err: err} + }() + select { + case res := <-done: + return res.value, res.err + case <-callCtx.Done(): + var zero T + return zero, callCtx.Err() + } +} diff --git a/internal/tinvest/retry_test.go b/internal/tinvest/retry_test.go index 709fc4e..621c5c5 100644 --- a/internal/tinvest/retry_test.go +++ b/internal/tinvest/retry_test.go @@ -24,6 +24,16 @@ func TestWithRetryRetriesUntilSuccess(t *testing.T) { } } +func TestRequestWithTimeoutReturnsDeadline(t *testing.T) { + _, err := requestWithTimeout(context.Background(), time.Millisecond, func() (int, error) { + time.Sleep(50 * time.Millisecond) + return 1, nil + }) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("err=%v, want DeadlineExceeded", err) + } +} + func TestWithRetryStopsOnContextCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() diff --git a/internal/tinvest/sandbox.go b/internal/tinvest/sandbox.go index c968650..86b9e92 100644 --- a/internal/tinvest/sandbox.go +++ b/internal/tinvest/sandbox.go @@ -43,17 +43,19 @@ func (g *SandboxGateway) PostLimitOrder(ctx context.Context, accountID, instrume if err != nil { return domain.Order{}, err } - resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PostOrderResponse, error) { - return g.sandbox.PostSandboxOrder(&investgo.PostOrderRequest{ - InstrumentId: instrumentUID, - Quantity: lots, - Price: quotation, - Direction: direction, - AccountId: accountID, - OrderType: pb.OrderType_ORDER_TYPE_LIMIT, - OrderId: clientOrderID, - TimeInForce: pb.TimeInForceType_TIME_IN_FORCE_DAY, - PriceType: pb.PriceType_PRICE_TYPE_CURRENCY, + resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.PostOrderResponse, error) { + return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PostOrderResponse, error) { + return g.sandbox.PostSandboxOrder(&investgo.PostOrderRequest{ + InstrumentId: instrumentUID, + Quantity: lots, + Price: quotation, + Direction: direction, + AccountId: accountID, + OrderType: pb.OrderType_ORDER_TYPE_LIMIT, + OrderId: clientOrderID, + TimeInForce: pb.TimeInForceType_TIME_IN_FORCE_DAY, + PriceType: pb.PriceType_PRICE_TYPE_CURRENCY, + }) }) }) if err != nil { @@ -66,18 +68,23 @@ func (g *SandboxGateway) CancelOrder(ctx context.Context, accountID, orderID str if err := ctx.Err(); err != nil { return err } - return withRetry(ctx, g.retryAttempts, g.retryBackoff, func() error { - _, err := g.sandbox.CancelSandboxOrder(accountID, orderID) - return err + _, err := requestWithTimeout(ctx, g.requestTimeout, func() (struct{}, error) { + return struct{}{}, withRetry(ctx, g.retryAttempts, g.retryBackoff, func() error { + _, err := g.sandbox.CancelSandboxOrder(accountID, orderID) + return err + }) }) + return err } func (g *SandboxGateway) GetOrderState(ctx context.Context, accountID, orderID string) (domain.Order, error) { if err := ctx.Err(); err != nil { return domain.Order{}, err } - resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrderStateResponse, error) { - return g.sandbox.GetSandboxOrderState(accountID, orderID) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.GetOrderStateResponse, error) { + return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrderStateResponse, error) { + return g.sandbox.GetSandboxOrderState(accountID, orderID) + }) }) if err != nil { return domain.Order{}, err @@ -89,8 +96,10 @@ func (g *SandboxGateway) GetActiveOrders(ctx context.Context, accountID string) if err := ctx.Err(); err != nil { return nil, err } - resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrdersResponse, error) { - return g.sandbox.GetSandboxOrders(accountID) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.GetOrdersResponse, error) { + return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrdersResponse, error) { + return g.sandbox.GetSandboxOrders(accountID) + }) }) if err != nil { return nil, err @@ -107,8 +116,10 @@ func (g *SandboxGateway) GetPortfolio(ctx context.Context, accountID string) (do if err := ctx.Err(); err != nil { return domain.Portfolio{}, err } - resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PortfolioResponse, error) { - return g.sandbox.GetSandboxPortfolio(accountID, pb.PortfolioRequest_RUB) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.PortfolioResponse, error) { + return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PortfolioResponse, error) { + return g.sandbox.GetSandboxPortfolio(accountID, pb.PortfolioRequest_RUB) + }) }) if err != nil { return domain.Portfolio{}, err @@ -120,11 +131,13 @@ func (g *SandboxGateway) GetOperations(ctx context.Context, accountID string, fr if err := ctx.Err(); err != nil { return nil, err } - resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.OperationsResponse, error) { - return g.sandbox.GetSandboxOperations(&investgo.GetOperationsRequest{ - AccountId: accountID, - From: from, - To: to, + resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.OperationsResponse, error) { + return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.OperationsResponse, error) { + return g.sandbox.GetSandboxOperations(&investgo.GetOperationsRequest{ + AccountId: accountID, + From: from, + To: to, + }) }) }) if err != nil {