Files
overnight-trading-bot/internal/features/pipeline_test.go
T

200 lines
6.3 KiB
Go
Raw Normal View History

2026-06-07 21:01:40 +00:00
package features
import (
2026-06-08 07:36:52 +00:00
"context"
2026-06-07 21:01:40 +00:00
"testing"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
2026-06-08 07:36:52 +00:00
"overnight-trading-bot/internal/testutil"
2026-06-07 21:51:20 +00:00
"overnight-trading-bot/internal/timeutil"
2026-06-07 21:01:40 +00:00
)
func TestComputeExpectedCostIncludesCommissionAndSlippage(t *testing.T) {
var candles []domain.Candle
start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
for i := 0; i < 6; i++ {
price := decimal.NewFromInt(int64(100 + i))
candles = append(candles, domain.Candle{
InstrumentUID: "uid",
TradeDate: start.AddDate(0, 0, i),
Open: price,
Close: price,
VolumeLots: decimal.NewFromInt(1000),
})
}
got, err := Compute(domain.Instrument{
InstrumentUID: "uid",
Lot: 1,
ExpectedCommissionBpsPerSide: decimal.NewFromInt(1),
}, candles, start.AddDate(0, 0, 5), SpreadResult{SpreadBps: decimal.NewFromInt(10)}, PipelineConfig{
RollingShort: 2,
RollingLong: 2,
EWMALambda: 0.08,
RiskBufferBps: decimal.NewFromInt(5),
EntrySlippageBps: decimal.NewFromInt(2),
ExitSlippageBps: decimal.NewFromInt(3),
CommissionRoundtripBps: decimal.NewFromInt(4),
}, decimal.NewFromInt(10000), decimal.NewFromInt(9000))
if err != nil {
t.Fatal(err)
}
2026-06-08 14:58:56 +00:00
if !got.ExpectedCostBps.Equal(decimal.NewFromInt(22)) {
t.Fatalf("expected cost=%s, want 22", got.ExpectedCostBps)
2026-06-07 21:01:40 +00:00
}
if !got.EntryIntervalVolume.Equal(decimal.NewFromInt(10000)) || !got.ExitIntervalVolume.Equal(decimal.NewFromInt(9000)) {
t.Fatalf("interval volumes were not preserved: %+v", got)
}
}
2026-06-08 09:03:37 +00:00
func TestComputeExpectedCostFallsBackToConfigCommission(t *testing.T) {
candles := flatCandles(time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC), 6)
got, err := Compute(domain.Instrument{
InstrumentUID: "uid",
Lot: 1,
}, candles, candles[5].TradeDate, SpreadResult{SpreadBps: decimal.NewFromInt(10)}, PipelineConfig{
RollingShort: 2,
RollingLong: 2,
EWMALambda: 0.08,
RiskBufferBps: decimal.NewFromInt(5),
EntrySlippageBps: decimal.NewFromInt(2),
ExitSlippageBps: decimal.NewFromInt(3),
CommissionRoundtripBps: decimal.NewFromInt(4),
}, decimal.Zero, decimal.Zero)
if err != nil {
t.Fatal(err)
}
2026-06-08 14:58:56 +00:00
if !got.ExpectedCostBps.Equal(decimal.NewFromInt(24)) {
t.Fatalf("expected cost=%s, want 24", got.ExpectedCostBps)
2026-06-08 09:03:37 +00:00
}
}
func TestComputeStoresHistoricalQ05Abs(t *testing.T) {
start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
returns := []string{"-0.10", "0.01", "0.02", "0.03", "0.04"}
candles := []domain.Candle{{
InstrumentUID: "uid",
TradeDate: start,
Open: decimal.NewFromInt(100),
Close: decimal.NewFromInt(100),
VolumeLots: decimal.NewFromInt(1),
}}
for i, raw := range returns {
r, err := decimal.NewFromString(raw)
if err != nil {
t.Fatal(err)
}
open := decimal.NewFromInt(100).Mul(decimal.NewFromInt(1).Add(r))
candles = append(candles, domain.Candle{
InstrumentUID: "uid",
TradeDate: start.AddDate(0, 0, i+1),
Open: open,
Close: decimal.NewFromInt(100),
VolumeLots: decimal.NewFromInt(1),
})
}
got, err := Compute(domain.Instrument{InstrumentUID: "uid", Lot: 1}, candles, start.AddDate(0, 0, 6), SpreadResult{}, PipelineConfig{
RollingShort: 5,
RollingLong: 5,
EWMALambda: 0.08,
}, decimal.Zero, decimal.Zero)
if err != nil {
t.Fatal(err)
}
want := decimal.NewFromFloat(0.078)
diff := got.Q05On60Abs.Sub(want)
if diff.IsNegative() {
diff = diff.Neg()
}
if diff.GreaterThan(decimal.NewFromFloat(0.000001)) {
t.Fatalf("Q05On60Abs=%s, want about %s", got.Q05On60Abs, want)
}
}
// flatCandles returns count daily candles for instrument "uid", one per
// calendar day beginning at start. Candle i opens and closes at 100+i and
// has a volume of 1000 lots.
func flatCandles(start time.Time, count int) []domain.Candle {
	lots := decimal.NewFromInt(1000)
	out := make([]domain.Candle, 0, count)
	for day := 0; day < count; day++ {
		level := decimal.NewFromInt(int64(100 + day))
		candle := domain.Candle{
			InstrumentUID: "uid",
			TradeDate:     start.AddDate(0, 0, day),
			Open:          level,
			Close:         level,
			VolumeLots:    lots,
		}
		out = append(out, candle)
	}
	return out
}
2026-06-07 21:01:40 +00:00
func TestIntervalVolume(t *testing.T) {
got := IntervalVolume([]domain.Candle{
{Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(10)},
{Close: decimal.NewFromInt(101), VolumeLots: decimal.NewFromInt(20)},
}, 2)
if !got.Equal(decimal.NewFromInt(6040)) {
t.Fatalf("interval volume=%s, want 6040", got)
}
}
2026-06-07 21:51:20 +00:00
// TestAverageIntervalVolumeUsesExecutionWindowsAcrossDays calls
// AverageIntervalVolume with an 18:20:00-18:40:00 window, a fixed UTC+3
// location and three candles spread over two days, and expects 1500.
func TestAverageIntervalVolumeUsesExecutionWindowsAcrossDays(t *testing.T) {
	msk := time.FixedZone("MSK", 3*60*60)
	execWindow := timeutil.Window{
		Start: mustTOD("18:20:00"),
		End:   mustTOD("18:40:00"),
	}

	ten := decimal.NewFromInt(10)
	// NOTE(review): the 15:50 UTC candle is 18:50 in msk and carries
	// sentinel 999 values; the expected 1500 == (100*10 + 200*10) / 2 only
	// holds if that candle is ignored — confirm against
	// AverageIntervalVolume.
	candles := []domain.Candle{
		{TradeDate: time.Date(2026, 6, 1, 15, 20, 0, 0, time.UTC), Close: decimal.NewFromInt(100), VolumeLots: ten},
		{TradeDate: time.Date(2026, 6, 1, 15, 50, 0, 0, time.UTC), Close: decimal.NewFromInt(999), VolumeLots: decimal.NewFromInt(999)},
		{TradeDate: time.Date(2026, 6, 2, 15, 25, 0, 0, time.UTC), Close: decimal.NewFromInt(200), VolumeLots: ten},
	}

	got := AverageIntervalVolume(candles, 1, execWindow, msk)
	if want := decimal.NewFromInt(1500); !got.Equal(want) {
		t.Fatalf("average interval volume=%s, want 1500", got)
	}
}
2026-06-08 07:36:52 +00:00
// TestRecomputeExcludesTradeDateDailyCandle stores six daily candles in an
// in-memory repository, the last of which (dated on the trade date) closes
// at 100000 instead of 100, then runs Pipeline.Recompute for that trade date
// and expects ADV20 to equal 100.
func TestRecomputeExcludesTradeDateDailyCandle(t *testing.T) {
	const days = 6

	ctx := context.Background()
	repo := testutil.NewMemoryRepository()
	start := time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC)
	tradeDate := start.AddDate(0, 0, days-1)

	candles := make([]domain.Candle, 0, days)
	for day := 0; day < days; day++ {
		candle := domain.Candle{
			InstrumentUID: "uid",
			TradeDate:     start.AddDate(0, 0, day),
			Open:          decimal.NewFromInt(100),
			Close:         decimal.NewFromInt(100),
			VolumeLots:    decimal.NewFromInt(1),
		}
		// The final candle gets an outlier close so that it would be
		// visible in ADV20 if it were used.
		if day == days-1 {
			candle.Close = decimal.NewFromInt(100000)
		}
		candles = append(candles, candle)
	}
	if err := repo.UpsertDailyCandles(ctx, candles); err != nil {
		t.Fatal(err)
	}

	cfg := PipelineConfig{
		RollingShort: 2,
		RollingLong:  2,
		EWMALambda:   0.08,
	}
	instrument := domain.Instrument{InstrumentUID: "uid", Lot: 1}

	got, err := NewPipeline(repo, cfg).Recompute(ctx, instrument, tradeDate, SpreadResult{})
	if err != nil {
		t.Fatal(err)
	}
	if want := decimal.NewFromInt(100); !got.ADV20.Equal(want) {
		t.Fatalf("ADV20=%s, want tradeDate candle excluded", got.ADV20)
	}
}
2026-06-07 21:51:20 +00:00
// mustTOD parses raw with timeutil.ParseTimeOfDay and returns the result,
// panicking with the parse error if parsing fails.
func mustTOD(raw string) timeutil.TimeOfDay {
	parsed, parseErr := timeutil.ParseTimeOfDay(raw)
	if parseErr == nil {
		return parsed
	}
	panic(parseErr)
}