fourth version
Deploy / Test, build and deploy (push) Failing after 3m7s

This commit is contained in:
2026-06-08 07:36:52 +00:00
parent 52a935b8b4
commit b9efa98758
20 changed files with 824 additions and 91 deletions
+2 -1
View File
@@ -76,6 +76,7 @@ APP_MODE=backtest go run ./cmd/bot
| `STRATEGY_ROLLING_SHORT` | количество торговых дней | `60` | рекомендуется `> 0` | Короткое окно статистики overnight-доходности. Больше - стабильнее оценка, но медленнее реакция; меньше - быстрее реакция, но больше шум. |
| `STRATEGY_ROLLING_LONG` | количество торговых дней | `252` | рекомендуется `>= STRATEGY_ROLLING_SHORT` и `> 0` | Длинное окно для проверки положительного долгосрочного edge и глубины backfill. Больше требует больше истории. |
| `STRATEGY_EWMA_LAMBDA` | дробь для EWMA | `0.08` | рабочий диапазон `(0, 1]`; вне диапазона EWMA-функция использует `0.08` | Вес новых наблюдений в EWMA. Больше - свежее движение влияет сильнее. |
| `STRATEGY_ALLOCATION_METHOD` | `equal_weight` | `equal_weight` | сейчас поддерживается только `equal_weight` | Метод распределения капитала между выбранными сигналами. Текущая реализация делит лимит экспозиции поровну между выбранными инструментами. |
| `STRATEGY_MIN_TSTAT_60` | decimal t-stat | `1.25` | валидации нет; обычно `>= 0` | Минимальная статистическая значимость короткого edge. Выше - меньше входов, ниже - больше входов. |
| `STRATEGY_MIN_WIN_RATE_60` | доля прибыльных overnight-дней | `0.55` | рекомендуется `0..1` | Минимальная доля положительных overnight-наблюдений. Выше - строже фильтр сигналов. |
| `STRATEGY_MIN_NET_EDGE_BPS` | bps | `10` | валидации нет; обычно `>= 0` | Минимальный ожидаемый edge после издержек. Выше - меньше, но потенциально качественнее сигналы. |
@@ -146,7 +147,7 @@ APP_MODE=backtest go run ./cmd/bot
| --- | --- | --- | --- | --- |
| `COMM_REQUIRE_ZERO_COMMISSION` | `true` или `false` | `true` | boolean | При `true` сигналы по инструментам с ожидаемой комиссией `> 0` отклоняются. |
| `COMM_QUARANTINE_ON_NONZERO` | `true` или `false` | `true` | boolean | При фактической брокерской комиссии `> 0` инструмент переводится в quarantine, а система останавливается через HALT по zero-commission policy. |
| `COMM_FREE_ORDER_COUNT_POLICY` | `submitted` | `submitted` | жёстко только `submitted` | Политика учёта бесплатных заявок: счётчик увеличивается при отправке заявки. Другие значения запрещены валидацией. |
| `COMM_FREE_ORDER_COUNT_POLICY` | `submitted` или `cancel_counts` | `submitted` | одно из двух значений | Политика учёта бесплатных заявок: `submitted` считает только отправку новой заявки, `cancel_counts` дополнительно считает успешные отмены перед repost. |
### BT
+1 -1
View File
@@ -130,7 +130,7 @@ func run() error {
MaxSpreadBps: spread,
MaxTickBps: tick,
AssumedSpreadBps: assumed,
RequireZeroCommission: *requireZeroCommission,
RequireZeroCommission: requireZeroCommission,
UseMinuteModel: *useMinuteModel,
})
result, err := engine.RunWithMinuteCandles(candles, minuteCandles)
+3
View File
@@ -244,6 +244,7 @@ func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Con
})
execEngine := execution.NewEngine(cfg.App.Mode, cfg.TInvest.AccountID, gateway, repo)
execEngine.SetMaxQuoteAge(time.Duration(cfg.Execution.MaxQuoteAgeSec) * time.Second)
execEngine.SetFreeOrderCountPolicy(cfg.Commission.FreeOrderCountPolicy)
services := scheduler.Services{
Repo: repo,
Gateway: gateway,
@@ -272,6 +273,7 @@ func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Con
EntryWindowEnd: cfg.Execution.EntryWindowEnd,
NoNewEntryAfter: cfg.Execution.NoNewEntryAfter,
ExitWatchStart: cfg.Execution.ExitWatchStart,
ExitNotBefore: cfg.Execution.ExitNotBefore,
ExitWindowStart: cfg.Execution.ExitWindowStart,
ExitWindowEnd: cfg.Execution.ExitWindowEnd,
HardExitDeadline: cfg.Execution.HardExitDeadline,
@@ -287,6 +289,7 @@ func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Con
APIOutageHalt: time.Duration(cfg.Risk.APIOutageHaltSec) * time.Second,
RequireZeroCommission: cfg.Commission.RequireZeroCommission,
QuarantineOnNonZero: cfg.Commission.QuarantineOnNonZero,
FreeOrderCountPolicy: cfg.Commission.FreeOrderCountPolicy,
ReconciliationInterval: 5 * time.Minute,
MaxOpenPositions: minPositive(cfg.Strategy.MaxPositions, cfg.Risk.MaxOpenPositions),
}, services)
+38
View File
@@ -0,0 +1,38 @@
package backtest
import (
"testing"
"github.com/shopspring/decimal"
)
func TestRequireZeroCommissionDefaultDoesNotOverrideExplicitFalse(t *testing.T) {
defaultEngine := New(Config{})
if !defaultEngine.requireZeroCommission() {
t.Fatal("default require_zero_commission should be true")
}
requireZero := false
explicitEngine := New(Config{RequireZeroCommission: &requireZero})
if explicitEngine.requireZeroCommission() {
t.Fatal("explicit require_zero_commission=false was overridden")
}
}
func TestAssumedSpreadUsesFundTypeSpecificDefaults(t *testing.T) {
engine := New(Config{
AssumedSpreadBps: decimal.NewFromInt(20),
InstrumentFundTypes: map[string]string{
"mm": "money_market",
"eq": "equity",
},
})
if got := engine.assumedSpreadBps("mm"); !got.Equal(decimal.NewFromInt(5)) {
t.Fatalf("money market spread=%s, want 5", got)
}
if got := engine.assumedSpreadBps("eq"); !got.Equal(decimal.NewFromInt(25)) {
t.Fatalf("equity spread=%s, want 25", got)
}
if got := engine.assumedSpreadBps("unknown"); !got.Equal(decimal.NewFromInt(20)) {
t.Fatalf("default spread=%s, want 20", got)
}
}
+78 -7
View File
@@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"sort"
"strings"
"time"
"github.com/shopspring/decimal"
@@ -33,8 +34,11 @@ type Config struct {
MinNetEdgeBps decimal.Decimal
MinADVRUB decimal.Decimal
MaxSpreadBps decimal.Decimal
MaxSpreadBpsMoneyMarket decimal.Decimal
MaxSpreadBpsBondFunds decimal.Decimal
MaxSpreadBpsEquityFunds decimal.Decimal
MaxTickBps decimal.Decimal
RequireZeroCommission bool
RequireZeroCommission *bool
MaxPositions int
MaxPositionPct decimal.Decimal
MaxTotalExposurePct decimal.Decimal
@@ -43,6 +47,8 @@ type Config struct {
RiskBudgetPct decimal.Decimal
MinOrderNotionalRUB decimal.Decimal
AssumedSpreadBps decimal.Decimal
AssumedSpreadBpsByFundType map[string]decimal.Decimal
InstrumentFundTypes map[string]string
AssumedTickBps decimal.Decimal
Lot int64
UseMinuteModel bool
@@ -120,6 +126,15 @@ func (cfg Config) withDefaults() Config {
if cfg.MaxSpreadBps.IsZero() {
cfg.MaxSpreadBps = decimal.NewFromInt(20)
}
if cfg.MaxSpreadBpsMoneyMarket.IsZero() {
cfg.MaxSpreadBpsMoneyMarket = decimal.NewFromInt(5)
}
if cfg.MaxSpreadBpsBondFunds.IsZero() {
cfg.MaxSpreadBpsBondFunds = decimal.NewFromInt(10)
}
if cfg.MaxSpreadBpsEquityFunds.IsZero() {
cfg.MaxSpreadBpsEquityFunds = decimal.NewFromInt(25)
}
if cfg.MaxTickBps.IsZero() {
cfg.MaxTickBps = decimal.NewFromInt(10)
}
@@ -132,8 +147,9 @@ func (cfg Config) withDefaults() Config {
if cfg.AssumedTickBps.IsZero() {
cfg.AssumedTickBps = cfg.MaxTickBps
}
if !cfg.RequireZeroCommission && cfg.CommissionRoundtripBps.IsZero() {
cfg.RequireZeroCommission = true
if cfg.RequireZeroCommission == nil {
requireZero := true
cfg.RequireZeroCommission = &requireZero
}
if cfg.MaxPositions == 0 {
cfg.MaxPositions = 5
@@ -260,7 +276,7 @@ func (e Engine) RunWithMinuteCandles(candlesByInstrument map[string][]domain.Can
Lots: lots,
Notional: notional,
NetPnL: pnl,
SpreadBps: e.cfg.AssumedSpreadBps,
SpreadBps: c.spreadBps,
SlippageBps: e.cfg.EntrySlippageBps.Add(e.cfg.ExitSlippageBps),
OvernightGap: c.overnightGap,
CapacityRUB: capacity,
@@ -355,6 +371,7 @@ type candidate struct {
buy decimal.Decimal
sell decimal.Decimal
netEdge decimal.Decimal
spreadBps decimal.Decimal
adv decimal.Decimal
q05Abs decimal.Decimal
overnightGap decimal.Decimal
@@ -381,7 +398,8 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle,
return candidate{}, false, nil
}
rawEdge := decimal.NewFromFloat(short.Mean).Mul(decimal.NewFromInt(10_000))
cost := e.cfg.AssumedSpreadBps.
spreadBps := e.assumedSpreadBps(instrumentUID)
cost := spreadBps.
Add(e.cfg.EntrySlippageBps).
Add(e.cfg.ExitSlippageBps).
Add(e.cfg.CommissionRoundtripBps).
@@ -389,7 +407,7 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle,
netEdge := rawEdge.Sub(cost)
adv := features.ADV(history, e.cfg.Lot, 20)
switch {
case e.cfg.RequireZeroCommission && e.cfg.CommissionRoundtripBps.IsPositive():
case e.requireZeroCommission() && e.cfg.CommissionRoundtripBps.IsPositive():
return candidate{}, false, nil
case !decimal.NewFromFloat(short.Mean).IsPositive() || !decimal.NewFromFloat(long.Mean).IsPositive():
return candidate{}, false, nil
@@ -399,7 +417,7 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle,
return candidate{}, false, nil
case netEdge.LessThan(e.cfg.MinNetEdgeBps):
return candidate{}, false, nil
case e.cfg.AssumedSpreadBps.GreaterThan(e.cfg.MaxSpreadBps):
case spreadBps.GreaterThan(e.maxSpreadBps(instrumentUID)):
return candidate{}, false, nil
case e.cfg.AssumedTickBps.GreaterThan(e.cfg.MaxTickBps):
return candidate{}, false, nil
@@ -425,6 +443,7 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle,
buy: buy,
sell: sell,
netEdge: netEdge,
spreadBps: spreadBps,
adv: adv,
q05Abs: q05Abs,
overnightGap: gap,
@@ -432,6 +451,58 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle,
}, true, nil
}
func (e Engine) requireZeroCommission() bool {
return e.cfg.RequireZeroCommission != nil && *e.cfg.RequireZeroCommission
}
func (e Engine) assumedSpreadBps(instrumentUID string) decimal.Decimal {
fundType := normalizedFundType(e.cfg.InstrumentFundTypes[instrumentUID])
if !fundType.IsZeroValue {
if spread, ok := e.cfg.AssumedSpreadBpsByFundType[fundType.Key]; ok {
return spread
}
return e.maxSpreadBpsForFundType(fundType.Raw)
}
return e.cfg.AssumedSpreadBps
}
func (e Engine) maxSpreadBps(instrumentUID string) decimal.Decimal {
fundType := normalizedFundType(e.cfg.InstrumentFundTypes[instrumentUID])
if fundType.IsZeroValue {
return e.cfg.MaxSpreadBps
}
return e.maxSpreadBpsForFundType(fundType.Raw)
}
func (e Engine) maxSpreadBpsForFundType(fundType string) decimal.Decimal {
switch {
case strings.Contains(fundType, "money"):
return e.cfg.MaxSpreadBpsMoneyMarket
case strings.Contains(fundType, "bond"):
return e.cfg.MaxSpreadBpsBondFunds
case strings.Contains(fundType, "equity"):
return e.cfg.MaxSpreadBpsEquityFunds
default:
return e.cfg.MaxSpreadBps
}
}
type normalizedType struct {
Raw string
Key string
IsZeroValue bool
}
func normalizedFundType(raw string) normalizedType {
raw = strings.ToLower(strings.TrimSpace(raw))
if raw == "" {
return normalizedType{IsZeroValue: true}
}
key := strings.ReplaceAll(raw, "-", "_")
key = strings.ReplaceAll(key, " ", "_")
return normalizedType{Raw: raw, Key: key}
}
func prepareCandles(candlesByInstrument map[string][]domain.Candle) map[string][]domain.Candle {
prepared := make(map[string][]domain.Candle, len(candlesByInstrument))
for instrumentUID, candles := range candlesByInstrument {
+14 -2
View File
@@ -71,6 +71,7 @@ type StrategyConfig struct {
RollingShort int `env:"ROLLING_SHORT" envDefault:"60"`
RollingLong int `env:"ROLLING_LONG" envDefault:"252"`
EWMALambda float64 `env:"EWMA_LAMBDA" envDefault:"0.08"`
AllocationMethod string `env:"ALLOCATION_METHOD" envDefault:"equal_weight"`
MinTStat60 decimal.Decimal `env:"MIN_TSTAT_60" envDefault:"1.25"`
MinWinRate60 decimal.Decimal `env:"MIN_WIN_RATE_60" envDefault:"0.55"`
MinNetEdgeBps decimal.Decimal `env:"MIN_NET_EDGE_BPS" envDefault:"10"`
@@ -203,8 +204,19 @@ func (c *Config) Validate() error {
if c.Risk.CommissionToleranceRUB.IsNegative() {
return errors.New("RISK_COMMISSION_TOLERANCE_RUB must be non-negative")
}
if c.Commission.FreeOrderCountPolicy != "submitted" {
return fmt.Errorf("COMM_FREE_ORDER_COUNT_POLICY must be submitted, got %q", c.Commission.FreeOrderCountPolicy)
if c.Commission.FreeOrderCountPolicy == "" {
c.Commission.FreeOrderCountPolicy = "submitted"
}
switch c.Commission.FreeOrderCountPolicy {
case "submitted", "cancel_counts":
default:
return fmt.Errorf("COMM_FREE_ORDER_COUNT_POLICY must be submitted or cancel_counts, got %q", c.Commission.FreeOrderCountPolicy)
}
if c.Strategy.AllocationMethod == "" {
c.Strategy.AllocationMethod = "equal_weight"
}
if c.Strategy.AllocationMethod != "equal_weight" {
return fmt.Errorf("STRATEGY_ALLOCATION_METHOD must be equal_weight, got %q", c.Strategy.AllocationMethod)
}
if err := c.validateWindows(); err != nil {
return err
+9
View File
@@ -19,6 +19,14 @@ func TestValidateRequiresAccountIDForBrokerModes(t *testing.T) {
}
}
func TestValidateAllowsCancelCountsFreeOrderPolicy(t *testing.T) {
cfg := minimalBrokerConfig(domain.ModeSandbox)
cfg.Commission.FreeOrderCountPolicy = "cancel_counts"
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate cancel_counts: %v", err)
}
}
func minimalBrokerConfig(mode domain.Mode) Config {
return Config{
App: AppConfig{
@@ -44,6 +52,7 @@ func minimalBrokerConfig(mode domain.Mode) Config {
QuoteDepth: 20,
OrderPollIntervalMS: 500,
},
Strategy: StrategyConfig{AllocationMethod: "equal_weight"},
Risk: RiskConfig{
APIOutageHaltSec: 180,
ReconciliationWindowHours: 72,
+36 -4
View File
@@ -17,6 +17,11 @@ import (
var ErrBrokerOrdersDisabled = errors.New("broker orders are disabled for current mode")
var ErrEmptyOrderBook = errors.New("order book has no usable bid/ask")
const (
FreeOrderPolicySubmitted = "submitted"
FreeOrderPolicyCancelCounts = "cancel_counts"
)
type Gateway interface {
PostLimitOrder(ctx context.Context, accountID, instrumentUID string, side domain.Side, lots int64, price decimal.Decimal, clientOrderID string) (domain.Order, error)
CancelOrder(ctx context.Context, accountID, orderID string) error
@@ -29,6 +34,7 @@ type Engine struct {
gateway Gateway
store repository.Repository
maxQuoteAge time.Duration
freeOrderCountPolicy string
mu sync.Map
}
@@ -44,13 +50,22 @@ type MonitorConfig struct {
}
func NewEngine(mode domain.Mode, accountID string, gateway Gateway, store repository.Repository) Engine {
return Engine{mode: mode, accountID: accountID, gateway: gateway, store: store}
return Engine{mode: mode, accountID: accountID, gateway: gateway, store: store, freeOrderCountPolicy: FreeOrderPolicySubmitted}
}
func (e *Engine) SetMaxQuoteAge(maxQuoteAge time.Duration) {
e.maxQuoteAge = maxQuoteAge
}
func (e *Engine) SetFreeOrderCountPolicy(policy string) {
switch policy {
case FreeOrderPolicyCancelCounts:
e.freeOrderCountPolicy = policy
default:
e.freeOrderCountPolicy = FreeOrderPolicySubmitted
}
}
func (e *Engine) PlaceEntry(ctx context.Context, accountIDHash string, instrument domain.Instrument, tradeDate time.Time, lots int64, book domain.OrderBook, improveTicks int, attempt int) (domain.Order, error) {
if err := e.checkQuoteFresh(book); err != nil {
return domain.Order{}, err
@@ -251,7 +266,15 @@ func (e *Engine) Cancel(ctx context.Context, order domain.Order) error {
return err
}
if e.store != nil {
return e.store.UpdateOrderStatus(ctx, order.ClientOrderID, domain.OrderStatusCancelled, order.FilledLots, order.RawStateJSON)
return e.store.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error {
if err := repo.UpdateOrderStatus(ctx, order.ClientOrderID, domain.OrderStatusCancelled, order.FilledLots, order.RawStateJSON); err != nil {
return err
}
if e.cancelCountsAsFreeOrder() {
return repo.IncrementFreeOrders(ctx, order.TradeDate, order.InstrumentUID, 1)
}
return nil
})
}
return nil
}
@@ -485,12 +508,21 @@ func (e *Engine) ensureRepostBudget(ctx context.Context, order domain.Order, ins
if err != nil {
return err
}
if instrument.FreeOrderLimitPerDay-sent < 1 {
return fmt.Errorf("%w: %s remaining=0", risk.ErrFreeOrderBudget, instrument.InstrumentUID)
needed := 1
if e.cancelCountsAsFreeOrder() {
needed = 2
}
remaining := instrument.FreeOrderLimitPerDay - sent
if remaining < needed {
return fmt.Errorf("%w: %s remaining=%d needed=%d", risk.ErrFreeOrderBudget, instrument.InstrumentUID, remaining, needed)
}
return nil
}
func (e *Engine) cancelCountsAsFreeOrder() bool {
return e.freeOrderCountPolicy == FreeOrderPolicyCancelCounts
}
func (e *Engine) checkQuoteFresh(book domain.OrderBook) error {
if e.maxQuoteAge <= 0 {
return nil
+34
View File
@@ -96,6 +96,40 @@ func TestPaperPlaceEntryFillsAndCountsSubmittedOrder(t *testing.T) {
}
}
func TestCancelCountsAsFreeOrderWhenPolicyEnabled(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
gateway := tinvest.NewFakeGateway()
engine := NewEngine(domain.ModeSandbox, "account", gateway, repo)
engine.SetFreeOrderCountPolicy(FreeOrderPolicyCancelCounts)
tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
order, err := engine.PlaceLimit(ctx, domain.Order{
ClientOrderID: "order-1",
AccountIDHash: "hash",
InstrumentUID: "uid",
TradeDate: tradeDate,
Side: domain.SideBuy,
OrderType: domain.OrderTypeLimit,
LimitPrice: decimal.NewFromInt(100),
QuantityLots: 1,
Status: domain.OrderStatusNew,
AttemptNo: 1,
})
if err != nil {
t.Fatal(err)
}
if err := engine.Cancel(ctx, order); err != nil {
t.Fatal(err)
}
sent, err := repo.GetFreeOrdersSent(ctx, tradeDate, "uid")
if err != nil {
t.Fatal(err)
}
if sent != 2 {
t.Fatalf("free order counter=%d, want submit+cancel", sent)
}
}
func TestPlaceEntryRejectsStaleQuote(t *testing.T) {
ctx := context.Background()
engine := NewEngine(domain.ModeSandbox, "account", tinvest.NewFakeGateway(), testutil.NewMemoryRepository())
+22 -3
View File
@@ -40,7 +40,8 @@ func NewPipeline(repo repository.Repository, cfg PipelineConfig) Pipeline {
func (p Pipeline) Recompute(ctx context.Context, instrument domain.Instrument, tradeDate time.Time, spread SpreadResult) (domain.FeatureSet, error) {
from := tradeDate.AddDate(0, 0, -p.cfg.RollingLong-5)
candles, err := p.repo.ListDailyCandles(ctx, instrument.InstrumentUID, from, tradeDate)
to := dateOnly(tradeDate).AddDate(0, 0, -1)
candles, err := p.repo.ListDailyCandles(ctx, instrument.InstrumentUID, from, to)
if err != nil {
return domain.FeatureSet{}, err
}
@@ -74,8 +75,9 @@ func (p Pipeline) intervalVolume(ctx context.Context, instrument domain.Instrume
if lookback <= 0 {
lookback = defaultIntervalVolumeLookback
}
from := window.Start.On(date.AddDate(0, 0, -lookback), loc).UTC()
to := window.End.On(date, loc).UTC()
toDate := dateOnly(date).AddDate(0, 0, -1)
from := window.Start.On(toDate.AddDate(0, 0, -lookback+1), loc).UTC()
to := window.End.On(toDate, loc).UTC()
candles, err := p.repo.ListMinuteCandles(ctx, instrument.InstrumentUID, from, to)
if err != nil {
return decimal.Zero, err
@@ -84,6 +86,7 @@ func (p Pipeline) intervalVolume(ctx context.Context, instrument domain.Instrume
}
func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate time.Time, spread SpreadResult, cfg PipelineConfig, entryVolume, exitVolume decimal.Decimal) (domain.FeatureSet, error) {
candles = historicalDailyCandles(candles, tradeDate)
if len(candles) < 2 {
return domain.FeatureSet{}, fmt.Errorf("need at least 2 candles, got %d", len(candles))
}
@@ -138,6 +141,22 @@ func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate ti
}, nil
}
func historicalDailyCandles(candles []domain.Candle, tradeDate time.Time) []domain.Candle {
tradeDay := dateOnly(tradeDate)
out := make([]domain.Candle, 0, len(candles))
for _, candle := range candles {
if dateOnly(candle.TradeDate).Before(tradeDay) {
out = append(out, candle)
}
}
return out
}
func dateOnly(ts time.Time) time.Time {
year, month, day := ts.UTC().Date()
return time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
}
func IntervalVolume(candles []domain.Candle, lot int64) decimal.Decimal {
if lot <= 0 {
return decimal.Zero
+37
View File
@@ -1,12 +1,14 @@
package features
import (
"context"
"testing"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/testutil"
"overnight-trading-bot/internal/timeutil"
)
@@ -74,6 +76,41 @@ func TestAverageIntervalVolumeUsesExecutionWindowsAcrossDays(t *testing.T) {
}
}
func TestRecomputeExcludesTradeDateDailyCandle(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
start := time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC)
var candles []domain.Candle
for i := 0; i < 6; i++ {
closePrice := decimal.NewFromInt(100)
if i == 5 {
closePrice = decimal.NewFromInt(100000)
}
candles = append(candles, domain.Candle{
InstrumentUID: "uid",
TradeDate: start.AddDate(0, 0, i),
Open: decimal.NewFromInt(100),
Close: closePrice,
VolumeLots: decimal.NewFromInt(1),
})
}
if err := repo.UpsertDailyCandles(ctx, candles); err != nil {
t.Fatal(err)
}
pipeline := NewPipeline(repo, PipelineConfig{
RollingShort: 2,
RollingLong: 2,
EWMALambda: 0.08,
})
got, err := pipeline.Recompute(ctx, domain.Instrument{InstrumentUID: "uid", Lot: 1}, start.AddDate(0, 0, 5), SpreadResult{})
if err != nil {
t.Fatal(err)
}
if !got.ADV20.Equal(decimal.NewFromInt(100)) {
t.Fatalf("ADV20=%s, want tradeDate candle excluded", got.ADV20)
}
}
func mustTOD(raw string) timeutil.TimeOfDay {
tod, err := timeutil.ParseTimeOfDay(raw)
if err != nil {
+60 -4
View File
@@ -1,9 +1,11 @@
package logging
import (
"fmt"
"io"
"log/slog"
"os"
"regexp"
"strings"
)
@@ -22,7 +24,10 @@ func New(level string, out io.Writer) *slog.Logger {
default:
slogLevel = slog.LevelInfo
}
return slog.New(slog.NewJSONHandler(out, &slog.HandlerOptions{Level: slogLevel}))
return slog.New(slog.NewJSONHandler(out, &slog.HandlerOptions{
Level: slogLevel,
ReplaceAttr: redactAttr,
}))
}
type SDKLogger struct {
@@ -31,18 +36,69 @@ type SDKLogger struct {
func (l SDKLogger) Infof(template string, args ...any) {
if l.Logger != nil {
l.Logger.Info(template, "args", args)
l.Logger.Info(RedactString(template), "args", redactArgs(args))
}
}
func (l SDKLogger) Errorf(template string, args ...any) {
if l.Logger != nil {
l.Logger.Error(template, "args", args)
l.Logger.Error(RedactString(template), "args", redactArgs(args))
}
}
func (l SDKLogger) Fatalf(template string, args ...any) {
if l.Logger != nil {
l.Logger.Error(template, "args", args)
l.Logger.Error(RedactString(template), "args", redactArgs(args))
}
}
var sensitiveStringPatterns = []*regexp.Regexp{
regexp.MustCompile(`(?i)((?:account[_-]?id|token)\s*[:=]\s*)("[^"]+"|'[^']+'|[^\s,}]+)`),
regexp.MustCompile(`(?i)("(?:accountId|account_id|token)"\s*:\s*)("[^"]*"|null)`),
}
func redactAttr(_ []string, attr slog.Attr) slog.Attr {
if attr.Value.Kind() == slog.KindString {
attr.Value = slog.StringValue(RedactString(attr.Value.String()))
}
return attr
}
func redactArgs(args []any) []any {
out := make([]any, len(args))
for i, arg := range args {
out[i] = redactAny(arg)
}
return out
}
func redactAny(value any) any {
switch typed := value.(type) {
case string:
return RedactString(typed)
case []string:
out := make([]string, len(typed))
for i, item := range typed {
out[i] = RedactString(item)
}
return out
case []any:
out := make([]any, len(typed))
for i, item := range typed {
out[i] = redactAny(item)
}
return out
case fmt.Stringer:
return RedactString(typed.String())
default:
return value
}
}
func RedactString(raw string) string {
redacted := raw
for _, pattern := range sensitiveStringPatterns {
redacted = pattern.ReplaceAllString(redacted, `${1}"[REDACTED]"`)
}
return redacted
}
+57
View File
@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"strings"
"sync"
"time"
"github.com/shopspring/decimal"
@@ -18,6 +19,7 @@ import (
var defaultCommissionTolerance = decimal.RequireFromString("0.01")
type Engine struct {
mu *sync.Mutex
repo repository.Repository
gateway tinvest.Gateway
accountID string
@@ -31,6 +33,7 @@ type Engine struct {
func New(repo repository.Repository, gateway tinvest.Gateway, accountID, accountIDHash string) Engine {
return Engine{
mu: &sync.Mutex{},
repo: repo,
gateway: gateway,
accountID: accountID,
@@ -64,6 +67,10 @@ func (e Engine) WithCommissionPolicy(requireZero, quarantineOnNonZero bool, tole
}
func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) {
if e.mu != nil {
e.mu.Lock()
defer e.mu.Unlock()
}
localOrders, err := e.repo.ListActiveOrders(ctx, e.accountIDHash)
if err != nil {
return nil, err
@@ -150,6 +157,7 @@ func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) {
})
}
}
diffs = append(diffs, compareCash(localPositions, portfolio, e.commissionTolerance)...)
from := now.Add(-e.window)
recentOrders, err := e.repo.ListOrders(ctx, e.accountIDHash, from, now)
if err != nil {
@@ -204,6 +212,55 @@ func compareOperations(orders []domain.Order, operations []domain.Operation) []d
return compareOperationsWithPolicy(orders, operations, false, defaultCommissionTolerance)
}
func compareCash(localPositions []domain.Position, portfolio domain.Portfolio, tolerance decimal.Decimal) []domain.ReconciliationDiff {
if tolerance.IsNegative() {
tolerance = decimal.Zero
}
expectedCash, ok := expectedCashFromLocalPositions(localPositions, portfolio)
if !ok {
return nil
}
diff := money.Abs(expectedCash.Sub(portfolio.Cash))
if diff.LessThanOrEqual(tolerance) {
return nil
}
return []domain.ReconciliationDiff{{
Kind: "cash_mismatch",
Message: fmt.Sprintf("expected cash=%s broker cash=%s diff=%s", expectedCash.StringFixed(2), portfolio.Cash.StringFixed(2), diff.StringFixed(2)),
Critical: true,
}}
}
func expectedCashFromLocalPositions(localPositions []domain.Position, portfolio domain.Portfolio) (decimal.Decimal, bool) {
if !portfolio.Equity.IsPositive() {
return decimal.Zero, false
}
if len(localPositions) == 0 {
if len(portfolio.Holdings) != 0 {
return decimal.Zero, false
}
return portfolio.Equity, true
}
holdingByInstrument := make(map[string]domain.Holding, len(portfolio.Holdings))
for _, holding := range portfolio.Holdings {
holdingByInstrument[holding.InstrumentUID] = holding
}
positionMarketValue := decimal.Zero
for _, pos := range localPositions {
if pos.Lots <= 0 {
continue
}
holding, ok := holdingByInstrument[pos.InstrumentUID]
if !ok || holding.QuantityLots <= 0 || !holding.MarketValue.IsPositive() {
return decimal.Zero, false
}
positionMarketValue = positionMarketValue.Add(holding.MarketValue.
Mul(decimal.NewFromInt(pos.Lots)).
Div(decimal.NewFromInt(holding.QuantityLots)))
}
return portfolio.Equity.Sub(positionMarketValue), true
}
func compareOperationsWithPolicy(orders []domain.Order, operations []domain.Operation, requireZeroCommission bool, commissionTolerance decimal.Decimal) []domain.ReconciliationDiff {
var diffs []domain.ReconciliationDiff
if commissionTolerance.IsNegative() {
+34
View File
@@ -170,3 +170,37 @@ func TestReconciliationSkipsFreshInFlightLocalOrders(t *testing.T) {
}
}
}
func TestReconciliationFindsCashMismatch(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
gateway := tinvest.NewFakeGateway()
if err := repo.UpsertPosition(ctx, domain.Position{
AccountIDHash: "hash",
InstrumentUID: "uid",
OpenTradeDate: time.Now().UTC(),
Lots: 2,
Status: domain.PositionHoldingOvernight,
}); err != nil {
t.Fatal(err)
}
gateway.Portfolio = domain.Portfolio{
Equity: decimal.NewFromInt(1000),
Cash: decimal.NewFromInt(700),
Holdings: []domain.Holding{{
InstrumentUID: "uid",
QuantityLots: 2,
MarketValue: decimal.NewFromInt(200),
}},
}
diffs, err := New(repo, gateway, "account", "hash").Run(ctx)
if err != nil {
t.Fatal(err)
}
for _, diff := range diffs {
if diff.Kind == "cash_mismatch" && diff.Critical {
return
}
}
t.Fatalf("missing cash_mismatch in %+v", diffs)
}
+142 -10
View File
@@ -45,6 +45,7 @@ type Config struct {
EntryWindowEnd timeutil.TimeOfDay
NoNewEntryAfter timeutil.TimeOfDay
ExitWatchStart timeutil.TimeOfDay
ExitNotBefore timeutil.TimeOfDay
ExitWindowStart timeutil.TimeOfDay
ExitWindowEnd timeutil.TimeOfDay
HardExitDeadline timeutil.TimeOfDay
@@ -60,6 +61,7 @@ type Config struct {
APIOutageHalt time.Duration
RequireZeroCommission bool
QuarantineOnNonZero bool
FreeOrderCountPolicy string
ReconciliationInterval time.Duration
MaxOpenPositions int
}
@@ -133,6 +135,13 @@ func (s *Scheduler) Step(ctx context.Context) error {
return err
}
now := s.clock.Now().In(s.cfg.Location)
reported, err := s.sendMissedDailyReport(ctx, now)
if err != nil {
return err
}
if reported {
return nil
}
phase := s.phase(now)
switch phase {
case domain.StateWaitExitWindow:
@@ -158,10 +167,14 @@ func (s *Scheduler) Step(ctx context.Context) error {
func (s Scheduler) phase(now time.Time) domain.SystemState {
tod := sinceMidnight(now)
exitWindowStart := s.cfg.ExitWindowStart.Duration
if s.cfg.ExitNotBefore.Duration > exitWindowStart {
exitWindowStart = s.cfg.ExitNotBefore.Duration
}
switch {
case tod >= s.cfg.ExitWatchStart.Duration && tod < s.cfg.ExitWindowStart.Duration:
case tod >= s.cfg.ExitWatchStart.Duration && tod < exitWindowStart:
return domain.StateWaitExitWindow
case tod >= s.cfg.ExitWindowStart.Duration && tod < s.cfg.ExitWindowEnd.Duration:
case tod >= exitWindowStart && tod < s.cfg.ExitWindowEnd.Duration:
return domain.StatePlaceExitOrders
case tod >= s.cfg.ExitWindowEnd.Duration && tod < s.cfg.HardExitDeadline.Duration:
return domain.StateMonitorExitOrders
@@ -463,7 +476,7 @@ func (s *Scheduler) placeEntryOrders(ctx context.Context, now time.Time) error {
}
continue
}
pre, err := s.preTradeCheck(ctx, now, portfolio, projectedOpenPositions, tradingStatus, book.ReceivedAt)
pre, err := s.preTradeCheck(ctx, now, sig.InstrumentUID, portfolio, projectedOpenPositions, tradingStatus, book.ReceivedAt)
if err != nil {
return err
}
@@ -585,7 +598,7 @@ func (s *Scheduler) placeExitOrders(ctx context.Context, now time.Time) error {
if !ok {
return fmt.Errorf("instrument %s is not in registry", pos.InstrumentUID)
}
if _, err := s.svc.FreeOrders.Check(ctx, exitTradeDate, instrument, s.cfg.MaxExitOrderAttempts); err != nil {
if _, err := s.svc.FreeOrders.Check(ctx, exitTradeDate, instrument, s.orderBudgetNeededForAttempts(s.cfg.MaxExitOrderAttempts)); err != nil {
if insertErr := s.recordPreTradeReject(ctx, pos.InstrumentUID, err.Error(), `{"reason":"free_order_budget_insufficient"}`); insertErr != nil {
return insertErr
}
@@ -609,7 +622,7 @@ func (s *Scheduler) placeExitOrders(ctx context.Context, now time.Time) error {
if err != nil {
return err
}
pre, err := s.preTradeCheck(ctx, now, portfolio, len(positionsList), tradingStatus, book.ReceivedAt)
pre, err := s.preTradeCheck(ctx, now, pos.InstrumentUID, portfolio, len(positionsList), tradingStatus, book.ReceivedAt)
if err != nil {
return err
}
@@ -722,12 +735,47 @@ func (s *Scheduler) reconcileAndReport(ctx context.Context, now time.Time) error
if err := s.transitionTo(ctx, domain.StateReconcile); err != nil {
return err
}
if s.cfg.Mode.AllowsBrokerOrders() {
if err := s.reconcileCritical(ctx, "reconciliation_critical"); err != nil {
return err
}
}
return s.sendDailyReport(ctx, now, "ok")
}
func (s *Scheduler) sendMissedDailyReport(ctx context.Context, now time.Time) (bool, error) {
if s.svc.Repo == nil || !s.hasStateMachine() {
return false, nil
}
tod := sinceMidnight(now)
if tod < s.cfg.EntrySignalTime.Duration {
return false, nil
}
phase := s.phase(now)
if phase == domain.StateReconcile || phase == domain.StateReport {
return false, nil
}
state, halted, _, err := s.svc.Repo.GetSystemState(ctx)
if err != nil {
return false, err
}
if halted || state == domain.StateHalted {
return false, nil
}
if state != domain.StateInit && state != domain.StateSleep {
return false, nil
}
tradeDate := tradingDate(now)
sent, err := s.svc.Repo.WasDailyReportSent(ctx, tradeDate, s.svc.AccountIDHash)
if err != nil {
return false, err
}
if sent {
return false, nil
}
return true, s.reconcileAndReport(ctx, now)
}
func (s *Scheduler) sendDailyReport(ctx context.Context, now time.Time, riskStatus string) error {
tradeDate := tradingDate(now)
sent, err := s.svc.Repo.WasDailyReportSent(ctx, tradeDate, s.svc.AccountIDHash)
@@ -1081,7 +1129,7 @@ func (s Scheduler) repostPreTradeCheck(ctx context.Context, now time.Time, order
if err != nil {
return err
}
pre, err := s.preTradeCheck(ctx, now, portfolio, len(openPositions), tradingStatus, book.ReceivedAt)
pre, err := s.preTradeCheck(ctx, now, order.InstrumentUID, portfolio, len(openPositions), tradingStatus, book.ReceivedAt)
if err != nil {
return err
}
@@ -1092,12 +1140,19 @@ func (s Scheduler) repostPreTradeCheck(ctx context.Context, now time.Time, order
return nil
}
func (s Scheduler) preTradeCheck(ctx context.Context, now time.Time, portfolio domain.Portfolio, openPositions int, tradingStatus domain.TradingStatus, quoteReceivedAt time.Time) (risk.PreTradeResult, error) {
func (s Scheduler) preTradeCheck(ctx context.Context, now time.Time, instrumentUID string, portfolio domain.Portfolio, openPositions int, tradingStatus domain.TradingStatus, quoteReceivedAt time.Time) (risk.PreTradeResult, error) {
metrics, err := s.riskMetrics(ctx, now, portfolio)
if err != nil {
if haltErr := s.halt(ctx, "database_unavailable", fmt.Sprintf("pre-trade risk metrics unavailable: %s", err), instrumentUID); haltErr != nil {
return risk.PreTradeResult{}, fmt.Errorf("database_unavailable: %w; halt failed: %v", err, haltErr)
}
return risk.PreTradeResult{Allowed: false, Reason: "database_unavailable"}, fmt.Errorf("%w: database_unavailable", statemachine.ErrSystemHalted)
}
unknownOrder, unknownHolding, err := s.unknownBrokerState(ctx, portfolio)
if err != nil {
return risk.PreTradeResult{}, err
}
return s.svc.Risk.PreTradeCheck(risk.PreTradeInput{
result := s.svc.Risk.PreTradeCheck(risk.PreTradeInput{
Portfolio: portfolio,
OpenPositions: openPositions,
DailyPnL: metrics.dailyPnL,
@@ -1108,7 +1163,73 @@ func (s Scheduler) preTradeCheck(ctx context.Context, now time.Time, portfolio d
QuoteReceivedAt: quoteReceivedAt,
Now: now.UTC(),
MarketClose: s.marketCloseOn(now),
}), nil
UnknownBrokerOrder: unknownOrder,
UnknownBrokerHolding: unknownHolding,
})
if !result.Allowed && isHardHaltPreTradeReason(result.Reason) {
if err := s.halt(ctx, result.Reason, fmt.Sprintf("pre-trade hard limit breached: %s", result.Reason), instrumentUID); err != nil {
return result, err
}
return result, fmt.Errorf("%w: %s", statemachine.ErrSystemHalted, result.Reason)
}
return result, nil
}
func (s Scheduler) unknownBrokerState(ctx context.Context, portfolio domain.Portfolio) (bool, bool, error) {
if !s.cfg.Mode.AllowsBrokerOrders() {
return false, false, nil
}
localOrders, err := s.svc.Repo.ListActiveOrders(ctx, s.svc.AccountIDHash)
if err != nil {
return false, false, err
}
localByBroker := make(map[string]struct{}, len(localOrders))
for _, order := range localOrders {
if order.BrokerOrderID != "" {
localByBroker[order.BrokerOrderID] = struct{}{}
}
}
brokerOrders, err := s.svc.Gateway.GetActiveOrders(ctx, s.svc.AccountID)
if err != nil {
return false, false, err
}
for _, brokerOrder := range brokerOrders {
if brokerOrder.BrokerOrderID == "" {
continue
}
if _, ok := localByBroker[brokerOrder.BrokerOrderID]; !ok {
return true, false, nil
}
}
localPositions, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash)
if err != nil {
return false, false, err
}
localLots := make(map[string]int64, len(localPositions))
for _, pos := range localPositions {
localLots[pos.InstrumentUID] += pos.Lots
}
for _, holding := range portfolio.Holdings {
if holding.QuantityLots > 0 && localLots[holding.InstrumentUID] == 0 {
return false, true, nil
}
}
return false, false, nil
}
func isHardHaltPreTradeReason(reason string) bool {
switch reason {
case "database_unavailable",
"unknown_broker_order",
"unknown_broker_position",
"trading_status_unknown_before_order",
"max_daily_loss",
"max_weekly_loss",
"max_monthly_drawdown":
return true
default:
return false
}
}
type preTradeMetrics struct {
@@ -1255,13 +1376,24 @@ func repostAfter(now, deadline time.Time, attempts int, poll time.Duration) time
}
func (s Scheduler) maxOrderAttemptsPerTrade() int {
needed := s.cfg.MaxEntryOrderAttempts + s.cfg.MaxExitOrderAttempts
needed := s.orderBudgetNeededForAttempts(s.cfg.MaxEntryOrderAttempts) + s.orderBudgetNeededForAttempts(s.cfg.MaxExitOrderAttempts)
if needed <= 0 {
return 1
}
return needed
}
func (s Scheduler) orderBudgetNeededForAttempts(attempts int) int {
if attempts <= 0 {
attempts = 1
}
needed := attempts
if s.cfg.FreeOrderCountPolicy == execution.FreeOrderPolicyCancelCounts {
needed += attempts - 1
}
return needed
}
func isSizingSkipReason(reason string) bool {
return reason == "lots_below_one" || reason == "min_order_notional"
}
+102
View File
@@ -2,6 +2,7 @@ package scheduler
import (
"context"
"errors"
"testing"
"time"
@@ -57,6 +58,33 @@ func TestPhaseUsesMoscowWindows(t *testing.T) {
}
}
func TestPhaseHonorsExitNotBeforeWhenWindowStartsEarlier(t *testing.T) {
loc := time.FixedZone("MSK", 3*60*60)
s := Scheduler{cfg: Config{
Location: loc,
EntrySignalTime: mustTOD("18:10:00"),
ExitWatchStart: mustTOD("09:50:00"),
ExitNotBefore: mustTOD("10:03:00"),
ExitWindowStart: mustTOD("10:00:00"),
ExitWindowEnd: mustTOD("10:25:00"),
HardExitDeadline: mustTOD("10:45:00"),
}}
at, err := time.Parse(time.RFC3339, "2026-06-06T10:01:00+03:00")
if err != nil {
t.Fatal(err)
}
if got := s.phase(at.In(loc)); got != domain.StateWaitExitWindow {
t.Fatalf("phase before ExitNotBefore=%s, want WAIT_EXIT_WINDOW", got)
}
at, err = time.Parse(time.RFC3339, "2026-06-06T10:04:00+03:00")
if err != nil {
t.Fatal(err)
}
if got := s.phase(at.In(loc)); got != domain.StatePlaceExitOrders {
t.Fatalf("phase after ExitNotBefore=%s, want PLACE_EXIT_ORDERS", got)
}
}
func TestInfrastructureOutageRequiresThreshold(t *testing.T) {
gateway := tinvest.NewFakeGateway()
gateway.ServerTime = time.Now().UTC().Add(-10 * time.Second)
@@ -260,6 +288,80 @@ func TestNonZeroCommissionQuarantinesInstrumentAndHalts(t *testing.T) {
}
}
func TestPreTradeDailyLossBreachHalts(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
now := time.Date(2026, 6, 8, 18, 20, 0, 0, time.UTC)
closedAt := now.Add(-time.Hour)
if err := repo.UpsertPosition(ctx, domain.Position{
AccountIDHash: "hash",
InstrumentUID: "uid",
OpenTradeDate: tradingDate(now),
Status: domain.PositionExitFilled,
NetPnL: decimal.NewFromInt(-200),
ClosedAt: &closedAt,
}); err != nil {
t.Fatal(err)
}
notifier := &countNotifier{}
s := Scheduler{
cfg: Config{Mode: domain.ModePaper, Location: time.UTC},
svc: Services{
Repo: repo,
Risk: risk.NewManager(repo, risk.ManagerConfig{MaxDailyLossPct: decimal.RequireFromString("0.01")}),
Notifier: notifier,
AccountIDHash: "hash",
},
}
_, err := s.preTradeCheck(ctx, now, "uid", domain.Portfolio{
Equity: decimal.NewFromInt(10000),
Cash: decimal.NewFromInt(10000),
}, 0, domain.TradingStatusNormal, now)
if !errors.Is(err, statemachine.ErrSystemHalted) {
t.Fatalf("err=%v, want ErrSystemHalted", err)
}
if !repo.Halted || repo.HaltReason != "pre-trade hard limit breached: max_daily_loss" {
t.Fatalf("halted=%v reason=%q", repo.Halted, repo.HaltReason)
}
if notifier.alerts != 1 {
t.Fatalf("alerts=%d, want 1", notifier.alerts)
}
}
func TestStepSendsMissedDailyReportAfterEntrySignalTime(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
notifier := &countNotifier{}
now := time.Date(2026, 6, 8, 18, 15, 0, 0, time.UTC)
s := Scheduler{
clock: fixedClock{now: now},
cfg: Config{
Mode: domain.ModePaper,
Location: time.UTC,
EntrySignalTime: mustTOD("18:10:00"),
},
sm: statemachine.New(repo, domain.ModePaper),
svc: Services{
Repo: repo,
Notifier: notifier,
AccountIDHash: "hash",
},
}
if err := s.Step(ctx); err != nil {
t.Fatal(err)
}
if notifier.reports != 1 {
t.Fatalf("reports=%d, want catch-up report", notifier.reports)
}
sent, err := repo.WasDailyReportSent(ctx, now, "hash")
if err != nil {
t.Fatal(err)
}
if !sent {
t.Fatal("daily report was not marked as sent")
}
}
func TestSizeReductionRuleCutsSizerAfterBadExpectedErrors(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
+3
View File
@@ -37,6 +37,9 @@ func (s System) Recover(ctx context.Context, reconcile reconciliation.Engine) (d
case domain.StatePlaceEntryOrders, domain.StateMonitorEntryOrders,
domain.StatePlaceExitOrders, domain.StateMonitorExitOrders,
domain.StateHoldOvernight:
if !s.mode.AllowsBrokerOrders() {
return state, nil
}
diffs, err := reconcile.Run(ctx)
if err != nil {
return "", err
+2 -2
View File
@@ -80,7 +80,7 @@ func TestCalendarRecoveryAllowsRestartInsideExitWindow(t *testing.T) {
func TestRecoverFromMonitorEntryHaltsOnCriticalReconciliationDiff(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
if err := repo.SaveSystemState(ctx, domain.StateMonitorEntryOrders, domain.ModePaper, false, "", "{}"); err != nil {
if err := repo.SaveSystemState(ctx, domain.StateMonitorEntryOrders, domain.ModeSandbox, false, "", "{}"); err != nil {
t.Fatal(err)
}
if err := repo.UpsertOrder(ctx, domain.Order{
@@ -97,7 +97,7 @@ func TestRecoverFromMonitorEntryHaltsOnCriticalReconciliationDiff(t *testing.T)
}); err != nil {
t.Fatal(err)
}
system := New(repo, domain.ModePaper)
system := New(repo, domain.ModeSandbox)
state, err := system.Recover(ctx, reconciliation.New(repo, tinvest.NewFakeGateway(), "account", "hash"))
if err == nil {
t.Fatal("expected critical reconciliation error")
+57 -3
View File
@@ -15,6 +15,7 @@ import (
"github.com/shopspring/decimal"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/logging"
@@ -407,6 +408,13 @@ func orderFromPostResponse(resp *pb.PostOrderResponse, accountID, clientOrderID
return domain.Order{}
}
now := time.Now().UTC()
avgFillPrice := decimal.Zero
if resp.GetLotsExecuted() > 0 {
avgFillPrice = money.MoneyValueToDecimal(resp.GetExecutedOrderPrice())
if !avgFillPrice.IsPositive() {
avgFillPrice = limitPrice
}
}
return domain.Order{
ClientOrderID: clientOrderID,
BrokerOrderID: resp.GetOrderId(),
@@ -417,7 +425,7 @@ func orderFromPostResponse(resp *pb.PostOrderResponse, accountID, clientOrderID
LimitPrice: limitPrice,
QuantityLots: resp.GetLotsRequested(),
FilledLots: resp.GetLotsExecuted(),
AvgFillPrice: limitPrice,
AvgFillPrice: avgFillPrice,
Status: mapOrderStatus(resp.GetExecutionReportStatus()),
Commission: money.MoneyValueToDecimal(resp.GetExecutedCommission()),
RawStateJSON: marshalProto(resp),
@@ -438,6 +446,10 @@ func orderFromState(state *pb.OrderState, accountID string) domain.Order {
if state.GetOrderDate() != nil {
orderDate = state.GetOrderDate().AsTime().UTC()
}
avgFillPrice := decimal.Zero
if state.GetLotsExecuted() > 0 {
avgFillPrice = money.MoneyValueToDecimal(state.GetAveragePositionPrice())
}
return domain.Order{
ClientOrderID: state.GetOrderRequestId(),
BrokerOrderID: state.GetOrderId(),
@@ -448,7 +460,7 @@ func orderFromState(state *pb.OrderState, accountID string) domain.Order {
LimitPrice: money.MoneyValueToDecimal(state.GetInitialSecurityPrice()),
QuantityLots: state.GetLotsRequested(),
FilledLots: state.GetLotsExecuted(),
AvgFillPrice: money.MoneyValueToDecimal(state.GetAveragePositionPrice()),
AvgFillPrice: avgFillPrice,
Status: mapOrderStatus(state.GetExecutionReportStatus()),
Commission: money.MoneyValueToDecimal(state.GetExecutedCommission()),
RawStateJSON: marshalProto(state),
@@ -478,10 +490,52 @@ func marshalProto(msg proto.Message) string {
if msg == nil {
return "{}"
}
raw, err := protojson.Marshal(msg)
sanitized := proto.Clone(msg)
clearSensitiveProtoFields(sanitized.ProtoReflect())
raw, err := protojson.Marshal(sanitized)
if err != nil {
fallback, _ := json.Marshal(map[string]string{"marshal_error": err.Error()})
return string(fallback)
}
return string(raw)
}
func clearSensitiveProtoFields(message protoreflect.Message) {
if !message.IsValid() {
return
}
fields := message.Descriptor().Fields()
for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
if isSensitiveProtoField(field.Name()) {
message.Clear(field)
continue
}
value := message.Get(field)
switch {
case field.IsList():
list := value.List()
if field.Kind() == protoreflect.MessageKind || field.Kind() == protoreflect.GroupKind {
for j := 0; j < list.Len(); j++ {
clearSensitiveProtoFields(list.Get(j).Message())
}
}
case field.IsMap():
if field.MapValue().Kind() == protoreflect.MessageKind || field.MapValue().Kind() == protoreflect.GroupKind {
value.Map().Range(func(_ protoreflect.MapKey, value protoreflect.Value) bool {
clearSensitiveProtoFields(value.Message())
return true
})
}
case field.Kind() == protoreflect.MessageKind || field.Kind() == protoreflect.GroupKind:
if message.Has(field) {
clearSensitiveProtoFields(value.Message())
}
}
}
}
func isSensitiveProtoField(name protoreflect.Name) bool {
normalized := strings.ReplaceAll(strings.ToLower(string(name)), "_", "")
return normalized == "accountid"
}
+39
View File
@@ -0,0 +1,39 @@
package tinvest
import (
"strings"
"testing"
pb "github.com/russianinvestments/invest-api-go-sdk/proto"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
)
func TestOrderFromPostResponseZeroFillHasZeroAvgPrice(t *testing.T) {
order := orderFromPostResponse(&pb.PostOrderResponse{
OrderId: "broker",
ExecutionReportStatus: pb.OrderExecutionReportStatus_EXECUTION_REPORT_STATUS_NEW,
LotsRequested: 1,
LotsExecuted: 0,
ExecutedOrderPrice: &pb.MoneyValue{Currency: "rub", Units: 100},
InstrumentUid: "uid",
}, "account", "client", domain.SideBuy, decimal.NewFromInt(100))
if !order.AvgFillPrice.IsZero() {
t.Fatalf("avg fill price=%s, want zero for unfilled order", order.AvgFillPrice)
}
}
func TestMarshalProtoRedactsAccountID(t *testing.T) {
raw := marshalProto(&pb.OrderTrades{
OrderId: "order",
AccountId: "plain-account-id",
InstrumentUid: "uid",
})
if strings.Contains(raw, "plain-account-id") || strings.Contains(raw, "accountId") || strings.Contains(raw, "account_id") {
t.Fatalf("raw proto leaked account id: %s", raw)
}
if !strings.Contains(raw, "order") {
t.Fatalf("sanitizer removed non-sensitive data: %s", raw)
}
}