first version

This commit is contained in:
2026-06-07 21:01:40 +00:00
parent ee7167accf
commit f19bab1100
79 changed files with 10355 additions and 145 deletions
+148
View File
@@ -0,0 +1,148 @@
package features
import (
"context"
"fmt"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/repository"
"overnight-trading-bot/internal/timeutil"
)
type PipelineConfig struct {
RollingShort int
RollingLong int
EWMALambda float64
RiskBufferBps decimal.Decimal
EntrySlippageBps decimal.Decimal
ExitSlippageBps decimal.Decimal
CommissionRoundtripBps decimal.Decimal
EntryWindow timeutil.Window
ExitWindow timeutil.Window
Location *time.Location
}
type Pipeline struct {
repo repository.Repository
cfg PipelineConfig
}
func NewPipeline(repo repository.Repository, cfg PipelineConfig) Pipeline {
return Pipeline{repo: repo, cfg: cfg}
}
func (p Pipeline) Recompute(ctx context.Context, instrument domain.Instrument, tradeDate time.Time, spread SpreadResult) (domain.FeatureSet, error) {
from := tradeDate.AddDate(0, 0, -p.cfg.RollingLong-5)
candles, err := p.repo.ListDailyCandles(ctx, instrument.InstrumentUID, from, tradeDate)
if err != nil {
return domain.FeatureSet{}, err
}
entryVolume, err := p.intervalVolume(ctx, instrument, tradeDate, p.cfg.EntryWindow)
if err != nil {
return domain.FeatureSet{}, err
}
exitVolume, err := p.intervalVolume(ctx, instrument, tradeDate.AddDate(0, 0, 1), p.cfg.ExitWindow)
if err != nil {
return domain.FeatureSet{}, err
}
feature, err := Compute(instrument, candles, tradeDate, spread, p.cfg, entryVolume, exitVolume)
if err != nil {
return domain.FeatureSet{}, err
}
if err := p.repo.UpsertFeature(ctx, feature); err != nil {
return domain.FeatureSet{}, err
}
return feature, nil
}
func (p Pipeline) intervalVolume(ctx context.Context, instrument domain.Instrument, date time.Time, window timeutil.Window) (decimal.Decimal, error) {
if window.Start.Duration == 0 && window.End.Duration == 0 {
return decimal.Zero, nil
}
loc := p.cfg.Location
if loc == nil {
loc = time.UTC
}
from := window.Start.On(date, loc).UTC()
to := window.End.On(date, loc).UTC()
candles, err := p.repo.ListMinuteCandles(ctx, instrument.InstrumentUID, from, to)
if err != nil {
return decimal.Zero, err
}
return IntervalVolume(candles, instrument.Lot), nil
}
func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate time.Time, spread SpreadResult, cfg PipelineConfig, entryVolume, exitVolume decimal.Decimal) (domain.FeatureSet, error) {
if len(candles) < 2 {
return domain.FeatureSet{}, fmt.Errorf("need at least 2 candles, got %d", len(candles))
}
var overnight []float64
var lastROn decimal.Decimal
var lastRDay decimal.Decimal
for i := 1; i < len(candles); i++ {
rOn, err := OvernightReturn(candles[i].Open, candles[i-1].Close)
if err != nil {
return domain.FeatureSet{}, err
}
rDay, err := IntradayReturn(candles[i].Close, candles[i].Open)
if err != nil {
return domain.FeatureSet{}, err
}
onFloat, _ := rOn.Float64()
overnight = append(overnight, onFloat)
lastROn = rOn
lastRDay = rDay
}
short := Rolling(overnight, cfg.RollingShort, cfg.EWMALambda)
long := Rolling(overnight, cfg.RollingLong, cfg.EWMALambda)
adv := ADV(candles, instrument.Lot, 20)
rawEdgeBps := decimal.NewFromFloat(short.Mean).Mul(decimal.NewFromInt(10_000))
if !entryVolume.IsPositive() {
entryVolume = adv
}
if !exitVolume.IsPositive() {
exitVolume = adv
}
instrumentCommission := instrument.ExpectedCommissionBpsPerSide.Mul(decimal.NewFromInt(2))
expectedCost := spread.SpreadBps.
Add(cfg.EntrySlippageBps).
Add(cfg.ExitSlippageBps).
Add(cfg.CommissionRoundtripBps).
Add(instrumentCommission).
Add(cfg.RiskBufferBps)
return domain.FeatureSet{
InstrumentUID: instrument.InstrumentUID,
TradeDate: tradeDate,
ROn: lastROn,
RDay: lastRDay,
MuOn60: decimal.NewFromFloat(short.Mean),
MuOn252: decimal.NewFromFloat(long.Mean),
SigmaOn60: decimal.NewFromFloat(short.StdDev),
TStatOn60: decimal.NewFromFloat(short.TStat),
WinOn60: decimal.NewFromFloat(short.WinRate),
EWMAOn: decimal.NewFromFloat(short.EWMA),
SpreadBps: spread.SpreadBps,
HalfSpreadBps: spread.HalfSpreadBps,
TickBps: spread.TickBps,
ADV20: adv,
ExpectedCostBps: expectedCost,
NetEdgeBps: rawEdgeBps.Sub(expectedCost),
EntryIntervalVolume: entryVolume,
ExitIntervalVolume: exitVolume,
CalculatedAt: time.Now().UTC(),
}, nil
}
func IntervalVolume(candles []domain.Candle, lot int64) decimal.Decimal {
if lot <= 0 {
return decimal.Zero
}
total := decimal.Zero
for _, candle := range candles {
total = total.Add(candle.VolumeLots.Mul(decimal.NewFromInt(lot)).Mul(candle.Close))
}
return total
}
+57
View File
@@ -0,0 +1,57 @@
package features
import (
"testing"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
)
func TestComputeExpectedCostIncludesCommissionAndSlippage(t *testing.T) {
var candles []domain.Candle
start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
for i := 0; i < 6; i++ {
price := decimal.NewFromInt(int64(100 + i))
candles = append(candles, domain.Candle{
InstrumentUID: "uid",
TradeDate: start.AddDate(0, 0, i),
Open: price,
Close: price,
VolumeLots: decimal.NewFromInt(1000),
})
}
got, err := Compute(domain.Instrument{
InstrumentUID: "uid",
Lot: 1,
ExpectedCommissionBpsPerSide: decimal.NewFromInt(1),
}, candles, start.AddDate(0, 0, 5), SpreadResult{SpreadBps: decimal.NewFromInt(10)}, PipelineConfig{
RollingShort: 2,
RollingLong: 2,
EWMALambda: 0.08,
RiskBufferBps: decimal.NewFromInt(5),
EntrySlippageBps: decimal.NewFromInt(2),
ExitSlippageBps: decimal.NewFromInt(3),
CommissionRoundtripBps: decimal.NewFromInt(4),
}, decimal.NewFromInt(10000), decimal.NewFromInt(9000))
if err != nil {
t.Fatal(err)
}
if !got.ExpectedCostBps.Equal(decimal.NewFromInt(26)) {
t.Fatalf("expected cost=%s, want 26", got.ExpectedCostBps)
}
if !got.EntryIntervalVolume.Equal(decimal.NewFromInt(10000)) || !got.ExitIntervalVolume.Equal(decimal.NewFromInt(9000)) {
t.Fatalf("interval volumes were not preserved: %+v", got)
}
}
func TestIntervalVolume(t *testing.T) {
got := IntervalVolume([]domain.Candle{
{Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(10)},
{Close: decimal.NewFromInt(101), VolumeLots: decimal.NewFromInt(20)},
}, 2)
if !got.Equal(decimal.NewFromInt(6040)) {
t.Fatalf("interval volume=%s, want 6040", got)
}
}
+207
View File
@@ -0,0 +1,207 @@
package features
import (
"errors"
"math"
"sort"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/money"
)
var ErrInvalidPrice = errors.New("price must be positive")
func OvernightReturn(open, previousClose decimal.Decimal) (decimal.Decimal, error) {
if !open.IsPositive() || !previousClose.IsPositive() {
return decimal.Zero, ErrInvalidPrice
}
return open.Div(previousClose).Sub(decimal.NewFromInt(1)), nil
}
func IntradayReturn(close, open decimal.Decimal) (decimal.Decimal, error) {
if !close.IsPositive() || !open.IsPositive() {
return decimal.Zero, ErrInvalidPrice
}
return close.Div(open).Sub(decimal.NewFromInt(1)), nil
}
func LogReturn(to, from decimal.Decimal) (float64, error) {
if !to.IsPositive() || !from.IsPositive() {
return 0, ErrInvalidPrice
}
ratio, _ := to.Div(from).Float64()
return math.Log(ratio), nil
}
func CumulativeLinear(returns []decimal.Decimal) decimal.Decimal {
total := decimal.NewFromInt(1)
for _, r := range returns {
total = total.Mul(decimal.NewFromInt(1).Add(r))
}
return total.Sub(decimal.NewFromInt(1))
}
func CumulativeLog(logReturns []float64) float64 {
sum := 0.0
for _, r := range logReturns {
sum += r
}
return math.Exp(sum) - 1
}
type RollingResult struct {
Mean float64
StdDev float64
TStat float64
WinRate float64
EWMA float64
Available bool
}
func Rolling(values []float64, window int, lambda float64) RollingResult {
if window <= 0 || len(values) < window {
return RollingResult{}
}
sample := values[len(values)-window:]
mean := Mean(sample)
std := StdDev(sample)
win := WinRate(sample)
ewma := EWMA(values, lambda)
res := RollingResult{
Mean: mean,
StdDev: std,
WinRate: win,
EWMA: ewma,
Available: true,
}
if std > 0 {
res.TStat = mean / std * math.Sqrt(float64(window))
}
return res
}
func Mean(values []float64) float64 {
if len(values) == 0 {
return 0
}
sum := 0.0
for _, value := range values {
sum += value
}
return sum / float64(len(values))
}
func StdDev(values []float64) float64 {
if len(values) < 2 {
return 0
}
mean := Mean(values)
sum := 0.0
for _, value := range values {
diff := value - mean
sum += diff * diff
}
return math.Sqrt(sum / float64(len(values)-1))
}
func WinRate(values []float64) float64 {
if len(values) == 0 {
return 0
}
wins := 0
for _, value := range values {
if value > 0 {
wins++
}
}
return float64(wins) / float64(len(values))
}
func EWMA(values []float64, lambda float64) float64 {
if len(values) == 0 {
return 0
}
if lambda <= 0 || lambda > 1 {
lambda = 0.08
}
ewma := values[0]
for _, value := range values[1:] {
ewma = lambda*value + (1-lambda)*ewma
}
return ewma
}
type SpreadResult struct {
SpreadAbs decimal.Decimal
SpreadBps decimal.Decimal
HalfSpreadBps decimal.Decimal
TickBps decimal.Decimal
Mid decimal.Decimal
}
func Spread(bestBid, bestAsk, tick decimal.Decimal) (SpreadResult, error) {
if !bestBid.IsPositive() || !bestAsk.IsPositive() || bestAsk.LessThanOrEqual(bestBid) {
return SpreadResult{}, ErrInvalidPrice
}
mid := bestAsk.Add(bestBid).Div(decimal.NewFromInt(2))
spreadAbs := bestAsk.Sub(bestBid)
spreadBps, err := money.Bps(spreadAbs, mid)
if err != nil {
return SpreadResult{}, err
}
tickBps := decimal.Zero
if tick.IsPositive() {
tickBps, err = money.Bps(tick, mid)
if err != nil {
return SpreadResult{}, err
}
}
return SpreadResult{
SpreadAbs: spreadAbs,
SpreadBps: spreadBps,
HalfSpreadBps: spreadBps.Div(decimal.NewFromInt(2)),
TickBps: tickBps,
Mid: mid,
}, nil
}
func ADV(candles []domain.Candle, lot int64, window int) decimal.Decimal {
if lot <= 0 || window <= 0 || len(candles) == 0 {
return decimal.Zero
}
sort.Slice(candles, func(i, j int) bool {
return candles[i].TradeDate.Before(candles[j].TradeDate)
})
if len(candles) > window {
candles = candles[len(candles)-window:]
}
total := decimal.Zero
for _, candle := range candles {
total = total.Add(candle.VolumeLots.Mul(decimal.NewFromInt(lot)).Mul(candle.Close))
}
return total.Div(decimal.NewFromInt(int64(len(candles))))
}
func Quantile(values []float64, q float64) float64 {
if len(values) == 0 {
return 0
}
cp := append([]float64(nil), values...)
sort.Float64s(cp)
if q <= 0 {
return cp[0]
}
if q >= 1 {
return cp[len(cp)-1]
}
pos := q * float64(len(cp)-1)
lower := int(math.Floor(pos))
upper := int(math.Ceil(pos))
if lower == upper {
return cp[lower]
}
weight := pos - float64(lower)
return cp[lower]*(1-weight) + cp[upper]*weight
}
+38
View File
@@ -0,0 +1,38 @@
package features
import (
"math"
"testing"
"github.com/shopspring/decimal"
)
func dec(raw string) decimal.Decimal {
v, err := decimal.NewFromString(raw)
if err != nil {
panic(err)
}
return v
}
func TestReturnsAndLogIdentity(t *testing.T) {
rOn, err := OvernightReturn(dec("102"), dec("100"))
if err != nil {
t.Fatal(err)
}
if !rOn.Equal(dec("0.02")) {
t.Fatalf("overnight return=%s", rOn)
}
rDay, err := IntradayReturn(dec("105"), dec("102"))
if err != nil {
t.Fatal(err)
}
if !rDay.Round(10).Equal(dec("0.0294117647")) {
t.Fatalf("intraday return=%s", rDay)
}
linear := CumulativeLinear([]decimal.Decimal{dec("0.01"), dec("-0.02"), dec("0.03")})
logs := []float64{math.Log(1.01), math.Log(0.98), math.Log(1.03)}
if math.Abs(linear.InexactFloat64()-CumulativeLog(logs)) > 1e-10 {
t.Fatalf("linear/log cumulative mismatch")
}
}
+30
View File
@@ -0,0 +1,30 @@
package features
import (
"math"
"testing"
)
func TestRollingStats(t *testing.T) {
values := []float64{0.01, -0.01, 0.02, 0.03}
got := Rolling(values, 4, 0.5)
if !got.Available {
t.Fatal("expected rolling result")
}
if math.Abs(got.Mean-0.0125) > 1e-12 {
t.Fatalf("mean=%f", got.Mean)
}
if math.Abs(got.WinRate-0.75) > 1e-12 {
t.Fatalf("win=%f", got.WinRate)
}
if got.StdDev <= 0 || got.TStat <= 0 {
t.Fatalf("std/tstat invalid: %+v", got)
}
}
func TestRollingSigmaZero(t *testing.T) {
got := Rolling([]float64{0.01, 0.01, 0.01}, 3, 0.08)
if got.StdDev != 0 || got.TStat != 0 {
t.Fatalf("expected zero sigma/tstat, got %+v", got)
}
}
+13
View File
@@ -0,0 +1,13 @@
package features
import "testing"
func TestSpread(t *testing.T) {
got, err := Spread(dec("99"), dec("101"), dec("0.1"))
if err != nil {
t.Fatal(err)
}
if !got.Mid.Equal(dec("100")) || !got.SpreadBps.Equal(dec("200")) || !got.HalfSpreadBps.Equal(dec("100")) || !got.TickBps.Equal(dec("10")) {
t.Fatalf("unexpected spread: %+v", got)
}
}