package features import ( "context" "testing" "time" "github.com/shopspring/decimal" "overnight-trading-bot/internal/domain" "overnight-trading-bot/internal/testutil" "overnight-trading-bot/internal/timeutil" ) func TestComputeExpectedCostIncludesCommissionAndSlippage(t *testing.T) { var candles []domain.Candle start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) for i := 0; i < 6; i++ { price := decimal.NewFromInt(int64(100 + i)) candles = append(candles, domain.Candle{ InstrumentUID: "uid", TradeDate: start.AddDate(0, 0, i), Open: price, Close: price, VolumeLots: decimal.NewFromInt(1000), }) } got, err := Compute(domain.Instrument{ InstrumentUID: "uid", Lot: 1, ExpectedCommissionBpsPerSide: decimal.NewFromInt(1), }, candles, start.AddDate(0, 0, 5), SpreadResult{SpreadBps: decimal.NewFromInt(10)}, PipelineConfig{ RollingShort: 2, RollingLong: 2, EWMALambda: 0.08, RiskBufferBps: decimal.NewFromInt(5), EntrySlippageBps: decimal.NewFromInt(2), ExitSlippageBps: decimal.NewFromInt(3), CommissionRoundtripBps: decimal.NewFromInt(4), }, decimal.NewFromInt(10000), decimal.NewFromInt(9000)) if err != nil { t.Fatal(err) } if !got.ExpectedCostBps.Equal(decimal.NewFromInt(22)) { t.Fatalf("expected cost=%s, want 22", got.ExpectedCostBps) } if !got.EntryIntervalVolume.Equal(decimal.NewFromInt(10000)) || !got.ExitIntervalVolume.Equal(decimal.NewFromInt(9000)) { t.Fatalf("interval volumes were not preserved: %+v", got) } } func TestComputeExpectedCostFallsBackToConfigCommission(t *testing.T) { candles := flatCandles(time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC), 6) got, err := Compute(domain.Instrument{ InstrumentUID: "uid", Lot: 1, }, candles, candles[5].TradeDate, SpreadResult{SpreadBps: decimal.NewFromInt(10)}, PipelineConfig{ RollingShort: 2, RollingLong: 2, EWMALambda: 0.08, RiskBufferBps: decimal.NewFromInt(5), EntrySlippageBps: decimal.NewFromInt(2), ExitSlippageBps: decimal.NewFromInt(3), CommissionRoundtripBps: decimal.NewFromInt(4), }, decimal.Zero, decimal.Zero) if err != nil { t.Fatal(err) } if !got.ExpectedCostBps.Equal(decimal.NewFromInt(24)) { t.Fatalf("expected cost=%s, want 24", got.ExpectedCostBps) } } func TestComputeStoresHistoricalQ05Abs(t *testing.T) { start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) returns := []string{"-0.10", "0.01", "0.02", "0.03", "0.04"} candles := []domain.Candle{{ InstrumentUID: "uid", TradeDate: start, Open: decimal.NewFromInt(100), Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(1), }} for i, raw := range returns { r, err := decimal.NewFromString(raw) if err != nil { t.Fatal(err) } open := decimal.NewFromInt(100).Mul(decimal.NewFromInt(1).Add(r)) candles = append(candles, domain.Candle{ InstrumentUID: "uid", TradeDate: start.AddDate(0, 0, i+1), Open: open, Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(1), }) } got, err := Compute(domain.Instrument{InstrumentUID: "uid", Lot: 1}, candles, start.AddDate(0, 0, 6), SpreadResult{}, PipelineConfig{ RollingShort: 5, RollingLong: 5, EWMALambda: 0.08, }, decimal.Zero, decimal.Zero) if err != nil { t.Fatal(err) } want := decimal.NewFromFloat(0.078) diff := got.Q05On60Abs.Sub(want) if diff.IsNegative() { diff = diff.Neg() } if diff.GreaterThan(decimal.NewFromFloat(0.000001)) { t.Fatalf("Q05On60Abs=%s, want about %s", got.Q05On60Abs, want) } } func flatCandles(start time.Time, count int) []domain.Candle { candles := make([]domain.Candle, 0, count) for i := 0; i < count; i++ { price := decimal.NewFromInt(int64(100 + i)) candles = append(candles, domain.Candle{ InstrumentUID: "uid", TradeDate: start.AddDate(0, 0, i), Open: price, Close: price, VolumeLots: decimal.NewFromInt(1000), }) } return candles } func TestIntervalVolume(t *testing.T) { got := IntervalVolume([]domain.Candle{ {Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(10)}, {Close: decimal.NewFromInt(101), VolumeLots: decimal.NewFromInt(20)}, }, 2) if !got.Equal(decimal.NewFromInt(6040)) { t.Fatalf("interval volume=%s, want 6040", got) } } func TestAverageIntervalVolumeUsesExecutionWindowsAcrossDays(t *testing.T) { loc := time.FixedZone("MSK", 3*60*60) window := timeutil.Window{ Start: mustTOD("18:20:00"), End: mustTOD("18:40:00"), } candles := []domain.Candle{ {TradeDate: time.Date(2026, 6, 1, 15, 20, 0, 0, time.UTC), Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(10)}, {TradeDate: time.Date(2026, 6, 1, 15, 50, 0, 0, time.UTC), Close: decimal.NewFromInt(999), VolumeLots: decimal.NewFromInt(999)}, {TradeDate: time.Date(2026, 6, 2, 15, 25, 0, 0, time.UTC), Close: decimal.NewFromInt(200), VolumeLots: decimal.NewFromInt(10)}, } got := AverageIntervalVolume(candles, 1, window, loc) if !got.Equal(decimal.NewFromInt(1500)) { t.Fatalf("average interval volume=%s, want 1500", got) } } func TestRecomputeExcludesTradeDateDailyCandle(t *testing.T) { ctx := context.Background() repo := testutil.NewMemoryRepository() start := time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC) var candles []domain.Candle for i := 0; i < 6; i++ { closePrice := decimal.NewFromInt(100) if i == 5 { closePrice = decimal.NewFromInt(100000) } candles = append(candles, domain.Candle{ InstrumentUID: "uid", TradeDate: start.AddDate(0, 0, i), Open: decimal.NewFromInt(100), Close: closePrice, VolumeLots: decimal.NewFromInt(1), }) } if err := repo.UpsertDailyCandles(ctx, candles); err != nil { t.Fatal(err) } pipeline := NewPipeline(repo, PipelineConfig{ RollingShort: 2, RollingLong: 2, EWMALambda: 0.08, }) got, err := pipeline.Recompute(ctx, domain.Instrument{InstrumentUID: "uid", Lot: 1}, start.AddDate(0, 0, 5), SpreadResult{}) if err != nil { t.Fatal(err) } if !got.ADV20.Equal(decimal.NewFromInt(100)) { t.Fatalf("ADV20=%s, want tradeDate candle excluded", got.ADV20) } } func mustTOD(raw string) timeutil.TimeOfDay { tod, err := timeutil.ParseTimeOfDay(raw) if err != nil { panic(err) } return tod }