diff --git a/internal/execution/engine.go b/internal/execution/engine.go index 70054f8..04a89cd 100644 --- a/internal/execution/engine.go +++ b/internal/execution/engine.go @@ -51,6 +51,12 @@ type MonitorConfig struct { RepostCheck func(ctx context.Context, order domain.Order, instrument domain.Instrument, book domain.OrderBook) error } +type repostResult struct { + Current domain.Order + Changed bool + Cancelled domain.Order +} + func NewEngine(mode domain.Mode, accountID string, gateway Gateway, store repository.Repository) Engine { return Engine{ mode: mode, @@ -346,13 +352,25 @@ func (e *Engine) MonitorUntil(ctx context.Context, order domain.Order, cfg Monit aggregate.FilledLots < aggregate.QuantityLots && cfg.Quote != nil if shouldRepost { - next, reposted, err := e.repost(ctx, current, cfg, aggregate.QuantityLots-aggregate.FilledLots) + result, err := e.repost(ctx, current, cfg, aggregate.QuantityLots-aggregate.FilledLots) + if result.Cancelled.ClientOrderID != "" { + previous := seen[result.Cancelled.ClientOrderID] + aggregate = mergeAggregateFill(aggregate, previous, result.Cancelled) + seen[result.Cancelled.ClientOrderID] = result.Cancelled + if aggregate.FilledLots >= aggregate.QuantityLots { + aggregate.Status = domain.OrderStatusFilled + return aggregate, nil + } + } if err != nil { return aggregate, err } - if reposted { - current = next + if result.Changed { + current = result.Current seen[current.ClientOrderID] = current + aggregate.Status = current.Status + aggregate.UpdatedAt = current.UpdatedAt + aggregate.RawStateJSON = current.RawStateJSON } lastPost = e.nowUTC() continue @@ -405,71 +423,86 @@ func (e *Engine) MonitorOnce(ctx context.Context, order domain.Order, cfg Monito aggregate.FilledLots < aggregate.QuantityLots && cfg.Quote != nil if shouldRepost { - next, reposted, err := e.repost(ctx, current, cfg, aggregate.QuantityLots-aggregate.FilledLots) + result, err := e.repost(ctx, current, cfg, aggregate.QuantityLots-aggregate.FilledLots) + if result.Cancelled.ClientOrderID != "" { + aggregate = mergeAggregateFill(aggregate, current, result.Cancelled) + if aggregate.FilledLots >= aggregate.QuantityLots { + aggregate.Status = domain.OrderStatusFilled + return aggregate, nil + } + } if err != nil { return aggregate, err } - if reposted { - aggregate.BrokerOrderID = next.BrokerOrderID - aggregate.ClientOrderID = next.ClientOrderID - aggregate.Status = next.Status - aggregate.RawStateJSON = next.RawStateJSON - aggregate.UpdatedAt = next.UpdatedAt + if result.Changed { + aggregate.BrokerOrderID = result.Current.BrokerOrderID + aggregate.ClientOrderID = result.Current.ClientOrderID + aggregate.Status = result.Current.Status + aggregate.RawStateJSON = result.Current.RawStateJSON + aggregate.UpdatedAt = result.Current.UpdatedAt } } return aggregate, nil } -func (e *Engine) repost(ctx context.Context, order domain.Order, cfg MonitorConfig, remaining int64) (domain.Order, bool, error) { +func (e *Engine) repost(ctx context.Context, order domain.Order, cfg MonitorConfig, remaining int64) (repostResult, error) { if err := e.ensureRepostBudget(ctx, order, cfg.Instrument); err != nil { - return domain.Order{}, false, err + return repostResult{}, err } if !cfg.Deadline.IsZero() && !e.nowUTC().Before(cfg.Deadline) { - return order, false, nil + return repostResult{Current: order}, nil } book, err := cfg.Quote(ctx, order.InstrumentUID) if err != nil { - return domain.Order{}, false, err + return repostResult{}, err } if cfg.RepostCheck != nil { if err := cfg.RepostCheck(ctx, order, cfg.Instrument, book); err != nil { - return order, false, nil + return repostResult{Current: order}, nil } } if err := e.Cancel(ctx, order); err != nil { - return domain.Order{}, false, err + return repostResult{}, err } cancelled, err := e.waitTerminal(ctx, order, cfg) if err != nil { - return domain.Order{}, false, err + return repostResult{}, err + } + result := repostResult{Current: cancelled, Changed: true, Cancelled: cancelled} + additionalFilled := cancelled.FilledLots - order.FilledLots + if additionalFilled > 0 { + remaining -= additionalFilled } if remaining <= 0 { - cancelled.Status = domain.OrderStatusFilled - return cancelled, true, nil + return result, nil } if !cfg.Deadline.IsZero() && !e.nowUTC().Before(cfg.Deadline) { - return cancelled, true, nil + return result, nil } book, err = cfg.Quote(ctx, order.InstrumentUID) if err != nil { - return domain.Order{}, false, err + return result, err } if cfg.RepostCheck != nil { if err := cfg.RepostCheck(ctx, cancelled, cfg.Instrument, book); err != nil { - return cancelled, true, nil + return result, nil } } attempt := order.AttemptNo + 1 + var next domain.Order switch order.Side { case domain.SideBuy: - next, err := e.PlaceEntry(ctx, order.AccountIDHash, cfg.Instrument, order.TradeDate, remaining, book, cfg.ImproveTicks, attempt) - return next, true, err + next, err = e.PlaceEntry(ctx, order.AccountIDHash, cfg.Instrument, order.TradeDate, remaining, book, cfg.ImproveTicks, attempt) case domain.SideSell: - next, err := e.PlaceExit(ctx, order.AccountIDHash, cfg.Instrument, order.TradeDate, remaining, book, cfg.ImproveTicks, attempt) - return next, true, err + next, err = e.PlaceExit(ctx, order.AccountIDHash, cfg.Instrument, order.TradeDate, remaining, book, cfg.ImproveTicks, attempt) default: - return domain.Order{}, false, fmt.Errorf("unsupported side %s", order.Side) + return result, fmt.Errorf("unsupported side %s", order.Side) } + if err != nil { + return result, err + } + result.Current = next + return result, nil } func (e *Engine) waitTerminal(ctx context.Context, order domain.Order, cfg MonitorConfig) (domain.Order, error) { diff --git a/internal/execution/state_test.go b/internal/execution/state_test.go index 2eae243..b898476 100644 --- a/internal/execution/state_test.go +++ b/internal/execution/state_test.go @@ -342,3 +342,164 @@ func TestMonitorOnceDoesNotRepostWhenCheckRejects(t *testing.T) { t.Fatalf("broker orders=%d, want no repost", got) } } + +func TestMonitorOnceRepostAccountsForFillsDuringCancel(t *testing.T) { + ctx := context.Background() + repo := testutil.NewMemoryRepository() + gateway := newCancelFillGateway(2) + engine := NewEngine(domain.ModeSandbox, "account", gateway, repo) + instrument := domain.Instrument{ + InstrumentUID: "uid", + Lot: 1, + MinPriceIncrement: decimal.NewFromInt(1), + FreeOrderLimitPerDay: -1, + } + book := domain.OrderBook{ + InstrumentUID: "uid", + Bids: []domain.OrderBookLevel{{Price: decimal.NewFromInt(99), QuantityLots: 10}}, + Asks: []domain.OrderBookLevel{{Price: decimal.NewFromInt(101), QuantityLots: 10}}, + ReceivedAt: time.Now().UTC(), + } + tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC) + order, err := engine.PlaceEntry(ctx, "hash", instrument, tradeDate, 5, book, 1, 1) + if err != nil { + t.Fatal(err) + } + order.CreatedAt = time.Now().UTC().Add(-time.Minute) + if err := repo.UpsertOrder(ctx, order); err != nil { + t.Fatal(err) + } + monitored, err := engine.MonitorOnce(ctx, order, MonitorConfig{ + Deadline: time.Now().Add(time.Minute), + PollInterval: time.Millisecond, + MaxAttempts: 2, + RepostAfter: time.Second, + Instrument: instrument, + ImproveTicks: 1, + Quote: func(context.Context, string) (domain.OrderBook, error) { + book.ReceivedAt = time.Now().UTC() + return book, nil + }, + }) + if err != nil { + t.Fatal(err) + } + if monitored.FilledLots != 2 { + t.Fatalf("aggregate filled lots=%d, want cancel fill 2", monitored.FilledLots) + } + if got := len(gateway.posted); got != 2 { + t.Fatalf("broker orders=%d, want initial+repost", got) + } + if got := gateway.posted[1].QuantityLots; got != 3 { + t.Fatalf("repost quantity lots=%d, want remaining 3", got) + } +} + +func TestMonitorOnceKeepsCancelFillWhenRepostPostFails(t *testing.T) { + ctx := context.Background() + repo := testutil.NewMemoryRepository() + gateway := newCancelFillGateway(2) + gateway.failPostAfter = 1 + engine := NewEngine(domain.ModeSandbox, "account", gateway, repo) + instrument := domain.Instrument{ + InstrumentUID: "uid", + Lot: 1, + MinPriceIncrement: decimal.NewFromInt(1), + FreeOrderLimitPerDay: -1, + } + book := domain.OrderBook{ + InstrumentUID: "uid", + Bids: []domain.OrderBookLevel{{Price: decimal.NewFromInt(99), QuantityLots: 10}}, + Asks: []domain.OrderBookLevel{{Price: decimal.NewFromInt(101), QuantityLots: 10}}, + ReceivedAt: time.Now().UTC(), + } + tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC) + order, err := engine.PlaceEntry(ctx, "hash", instrument, tradeDate, 5, book, 1, 1) + if err != nil { + t.Fatal(err) + } + order.CreatedAt = time.Now().UTC().Add(-time.Minute) + if err := repo.UpsertOrder(ctx, order); err != nil { + t.Fatal(err) + } + monitored, err := engine.MonitorOnce(ctx, order, MonitorConfig{ + Deadline: time.Now().Add(time.Minute), + PollInterval: time.Millisecond, + MaxAttempts: 2, + RepostAfter: time.Second, + Instrument: instrument, + ImproveTicks: 1, + Quote: func(context.Context, string) (domain.OrderBook, error) { + book.ReceivedAt = time.Now().UTC() + return book, nil + }, + }) + if err == nil { + t.Fatal("expected repost post error") + } + if monitored.FilledLots != 2 { + t.Fatalf("aggregate filled lots=%d, want cancel fill 2 despite error", monitored.FilledLots) + } +} + +type cancelFillGateway struct { + orders map[string]domain.Order + posted []domain.Order + fillLotsOnCancel int64 + failPostAfter int +} + +func newCancelFillGateway(fillLotsOnCancel int64) *cancelFillGateway { + return &cancelFillGateway{ + orders: make(map[string]domain.Order), + fillLotsOnCancel: fillLotsOnCancel, + } +} + +func (g *cancelFillGateway) PostLimitOrder(_ context.Context, accountID, instrumentUID string, side domain.Side, lots int64, price decimal.Decimal, clientOrderID string) (domain.Order, error) { + if g.failPostAfter > 0 && len(g.posted) >= g.failPostAfter { + return domain.Order{}, errors.New("post failed") + } + now := time.Now().UTC() + order := domain.Order{ + ClientOrderID: clientOrderID, + BrokerOrderID: "broker-" + clientOrderID, + AccountIDHash: accountID, + InstrumentUID: instrumentUID, + Side: side, + OrderType: domain.OrderTypeLimit, + LimitPrice: price, + QuantityLots: lots, + Status: domain.OrderStatusSent, + RawStateJSON: "{}", + CreatedAt: now, + UpdatedAt: now, + } + g.orders[order.BrokerOrderID] = order + g.posted = append(g.posted, order) + return order, nil +} + +func (g *cancelFillGateway) CancelOrder(_ context.Context, _ string, orderID string) error { + order, ok := g.orders[orderID] + if !ok { + return tinvest.ErrNotFound + } + fillLots := min(g.fillLotsOnCancel, order.QuantityLots) + if fillLots > order.FilledLots { + order.FilledLots = fillLots + order.AvgFillPrice = order.LimitPrice + } + order.Status = domain.OrderStatusCancelled + order.UpdatedAt = time.Now().UTC() + g.orders[orderID] = order + return nil +} + +func (g *cancelFillGateway) GetOrderState(_ context.Context, _ string, orderID string) (domain.Order, error) { + order, ok := g.orders[orderID] + if !ok { + return domain.Order{}, tinvest.ErrNotFound + } + return order, nil +} diff --git a/internal/risk/manager.go b/internal/risk/manager.go index f253c51..528290e 100644 --- a/internal/risk/manager.go +++ b/internal/risk/manager.go @@ -33,6 +33,7 @@ type ManagerConfig struct { type PreTradeInput struct { Portfolio domain.Portfolio OpenPositions int + ClosingPosition bool DailyPnL decimal.Decimal WeeklyPnL decimal.Decimal MonthlyDrawdownPct decimal.Decimal @@ -91,7 +92,7 @@ func (m Manager) PreTradeCheck(input PreTradeInput) PreTradeResult { return reject("trading_status_unknown_before_order") case input.TradingStatus != domain.TradingStatusNormal: return reject("trading_status_not_normal") - case m.cfg.MaxOpenPositions > 0 && input.OpenPositions >= m.cfg.MaxOpenPositions: + case !input.ClosingPosition && m.cfg.MaxOpenPositions > 0 && input.OpenPositions >= m.cfg.MaxOpenPositions: return reject("max_open_positions") case DailyLossBreached(input.DailyPnL, input.Portfolio.Equity, m.cfg.MaxDailyLossPct): return reject("max_daily_loss") diff --git a/internal/risk/manager_test.go b/internal/risk/manager_test.go new file mode 100644 index 0000000..abdd093 --- /dev/null +++ b/internal/risk/manager_test.go @@ -0,0 +1,28 @@ +package risk + +import ( + "testing" + + "github.com/shopspring/decimal" + + "overnight-trading-bot/internal/domain" +) + +func TestPreTradeClosingPositionBypassesOpenPositionLimit(t *testing.T) { + manager := NewManager(nil, ManagerConfig{MaxOpenPositions: 1}) + input := PreTradeInput{ + Portfolio: domain.Portfolio{Equity: decimal.NewFromInt(1000)}, + OpenPositions: 1, + TradingStatus: domain.TradingStatusNormal, + ClosingPosition: true, + } + result := manager.PreTradeCheck(input) + if !result.Allowed { + t.Fatalf("closing position rejected: %s", result.Reason) + } + input.ClosingPosition = false + result = manager.PreTradeCheck(input) + if result.Allowed || result.Reason != "max_open_positions" { + t.Fatalf("entry result=%+v, want max_open_positions reject", result) + } +} diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index 8072387..5afb4e8 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -492,7 +492,7 @@ func (s *Scheduler) placeEntryOrders(ctx context.Context, now time.Time) error { } continue } - pre, err := s.preTradeCheck(ctx, now, sig.InstrumentUID, portfolio, projectedOpenPositions, tradingStatus, book.ReceivedAt) + pre, err := s.preTradeCheck(ctx, now, sig.InstrumentUID, portfolio, projectedOpenPositions, false, tradingStatus, book.ReceivedAt) if err != nil { return err } @@ -567,7 +567,11 @@ func (s *Scheduler) monitorEntryOrders(ctx context.Context, now time.Time) error return err } if monitored.FilledLots > order.FilledLots || monitored.Commission.GreaterThan(order.Commission) { - if err := s.recordEntryFill(ctx, instrument, monitored); err != nil { + fill := entryFillDelta(order, monitored) + if fill.FilledLots <= 0 && fill.Commission.IsZero() { + continue + } + if err := s.recordEntryFill(ctx, instrument, fill); err != nil { return err } } @@ -660,7 +664,7 @@ func (s *Scheduler) placeExitOrders(ctx context.Context, now time.Time) error { if err != nil { return err } - pre, err := s.preTradeCheck(ctx, now, pos.InstrumentUID, portfolio, len(positionsList), tradingStatus, book.ReceivedAt) + pre, err := s.preTradeCheck(ctx, now, pos.InstrumentUID, portfolio, len(positionsList), true, tradingStatus, book.ReceivedAt) if err != nil { return err } @@ -1173,7 +1177,7 @@ func (s Scheduler) repostPreTradeCheck(ctx context.Context, now time.Time, order if err != nil { return err } - pre, err := s.preTradeCheck(ctx, now, order.InstrumentUID, portfolio, len(openPositions), tradingStatus, book.ReceivedAt) + pre, err := s.preTradeCheck(ctx, now, order.InstrumentUID, portfolio, len(openPositions), order.Side == domain.SideSell, tradingStatus, book.ReceivedAt) if err != nil { return err } @@ -1194,7 +1198,7 @@ func (s Scheduler) checkEntryInstrumentBeforeOrder(instrument domain.Instrument, return nil } -func (s Scheduler) preTradeCheck(ctx context.Context, now time.Time, instrumentUID string, portfolio domain.Portfolio, openPositions int, tradingStatus domain.TradingStatus, quoteReceivedAt time.Time) (risk.PreTradeResult, error) { +func (s Scheduler) preTradeCheck(ctx context.Context, now time.Time, instrumentUID string, portfolio domain.Portfolio, openPositions int, closingPosition bool, tradingStatus domain.TradingStatus, quoteReceivedAt time.Time) (risk.PreTradeResult, error) { metrics, err := s.riskMetrics(ctx, now, portfolio) if err != nil { if haltErr := s.halt(ctx, "database_unavailable", fmt.Sprintf("pre-trade risk metrics unavailable: %s", err), instrumentUID); haltErr != nil { @@ -1209,6 +1213,7 @@ func (s Scheduler) preTradeCheck(ctx context.Context, now time.Time, instrumentU result := s.svc.Risk.PreTradeCheck(risk.PreTradeInput{ Portfolio: portfolio, OpenPositions: openPositions, + ClosingPosition: closingPosition, DailyPnL: metrics.dailyPnL, WeeklyPnL: metrics.weeklyPnL, MonthlyDrawdownPct: metrics.monthlyDrawdownPct, @@ -1509,6 +1514,28 @@ func (s Scheduler) logWarn(msg string, args ...any) { } } +func entryFillDelta(previous, current domain.Order) domain.Order { + fill := current + fill.FilledLots = current.FilledLots - previous.FilledLots + if fill.FilledLots < 0 { + fill.FilledLots = 0 + } + fill.Commission = current.Commission.Sub(previous.Commission) + if fill.Commission.IsNegative() { + fill.Commission = decimal.Zero + } + if fill.FilledLots > 0 { + currentValue := current.AvgFillPrice.Mul(decimal.NewFromInt(current.FilledLots)) + previousValue := previous.AvgFillPrice.Mul(decimal.NewFromInt(previous.FilledLots)) + fill.AvgFillPrice = currentValue.Sub(previousValue).Div(decimal.NewFromInt(fill.FilledLots)) + } + fill.QuantityLots = current.QuantityLots - previous.FilledLots + if fill.QuantityLots < 0 { + fill.QuantityLots = 0 + } + return fill +} + func exitFillDelta(previous, current domain.Order) domain.Order { fill := current fill.FilledLots = current.FilledLots - previous.FilledLots diff --git a/internal/scheduler/scheduler_test.go b/internal/scheduler/scheduler_test.go index c857c0e..4653036 100644 --- a/internal/scheduler/scheduler_test.go +++ b/internal/scheduler/scheduler_test.go @@ -162,6 +162,36 @@ func TestExitFillDeltaUsesOnlyNewlyExecutedLots(t *testing.T) { } } +func TestEntryFillDeltaUsesOnlyNewlyExecutedLots(t *testing.T) { + previous := domain.Order{ + QuantityLots: 10, + FilledLots: 4, + AvgFillPrice: decimal.NewFromInt(100), + Commission: decimal.NewFromFloat(0.40), + InstrumentUID: "uid", + } + current := domain.Order{ + QuantityLots: 10, + FilledLots: 10, + AvgFillPrice: decimal.NewFromInt(106), + Commission: decimal.NewFromFloat(1.00), + InstrumentUID: "uid", + } + fill := entryFillDelta(previous, current) + if fill.FilledLots != 6 { + t.Fatalf("delta filled lots=%d, want 6", fill.FilledLots) + } + if fill.QuantityLots != 6 { + t.Fatalf("delta quantity lots=%d, want 6 remaining target", fill.QuantityLots) + } + if !fill.AvgFillPrice.Equal(decimal.NewFromInt(110)) { + t.Fatalf("delta avg fill price=%s, want 110", fill.AvgFillPrice) + } + if !fill.Commission.Equal(decimal.NewFromFloat(0.60)) { + t.Fatalf("delta commission=%s, want 0.60", fill.Commission) + } +} + func TestHardDeadlineMarksOpenPositionFailedAndHalts(t *testing.T) { ctx := context.Background() repo := testutil.NewMemoryRepository() @@ -344,7 +374,7 @@ func TestPreTradeDailyLossBreachHalts(t *testing.T) { _, err := s.preTradeCheck(ctx, now, "uid", domain.Portfolio{ Equity: decimal.NewFromInt(10000), Cash: decimal.NewFromInt(10000), - }, 0, domain.TradingStatusNormal, now) + }, 0, false, domain.TradingStatusNormal, now) if !errors.Is(err, statemachine.ErrSystemHalted) { t.Fatalf("err=%v, want ErrSystemHalted", err) } diff --git a/internal/tinvest/real.go b/internal/tinvest/real.go index cf90980..dbc0dd1 100644 --- a/internal/tinvest/real.go +++ b/internal/tinvest/real.go @@ -8,11 +8,14 @@ import ( "log/slog" "net/http" "strings" + "sync" "time" "github.com/russianinvestments/invest-api-go-sdk/investgo" pb "github.com/russianinvestments/invest-api-go-sdk/proto" "github.com/shopspring/decimal" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" @@ -35,14 +38,15 @@ type Options struct { type RealGateway struct { client *investgo.Client - instruments *investgo.InstrumentsServiceClient - marketData *investgo.MarketDataServiceClient - orders *investgo.OrdersServiceClient - operations *investgo.OperationsServiceClient - users *investgo.UsersServiceClient + instrumentsPB pb.InstrumentsServiceClient + marketDataPB pb.MarketDataServiceClient + ordersPB pb.OrdersServiceClient + operationsPB pb.OperationsServiceClient + usersPB pb.UsersServiceClient requestTimeout time.Duration retryAttempts int retryBackoff time.Duration + instrumentLots sync.Map } func NewRealGateway(ctx context.Context, opts Options) (*RealGateway, error) { @@ -61,11 +65,11 @@ func NewRealGateway(ctx context.Context, opts Options) (*RealGateway, error) { } return &RealGateway{ client: client, - instruments: client.NewInstrumentsServiceClient(), - marketData: client.NewMarketDataServiceClient(), - orders: client.NewOrdersServiceClient(), - operations: client.NewOperationsServiceClient(), - users: client.NewUsersServiceClient(), + instrumentsPB: pb.NewInstrumentsServiceClient(client.Conn), + marketDataPB: pb.NewMarketDataServiceClient(client.Conn), + ordersPB: pb.NewOrdersServiceClient(client.Conn), + operationsPB: pb.NewOperationsServiceClient(client.Conn), + usersPB: pb.NewUsersServiceClient(client.Conn), requestTimeout: opts.RequestTimeout, retryAttempts: opts.RetryCount, retryBackoff: opts.RetryBackoff, @@ -83,9 +87,13 @@ func (g *RealGateway) GetInstrument(ctx context.Context, ticker, classCode strin if err := ctx.Err(); err != nil { return domain.Instrument{}, err } - resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.EtfResponse, error) { - return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.EtfResponse, error) { - return g.instruments.EtfByTicker(ticker, classCode) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (*pb.EtfResponse, error) { + return retryValue(callCtx, g.retryAttempts, g.retryBackoff, func() (*pb.EtfResponse, error) { + return g.instrumentsPB.EtfBy(callCtx, &pb.InstrumentRequest{ + IdType: pb.InstrumentIdType_INSTRUMENT_ID_TYPE_TICKER, + ClassCode: &classCode, + Id: ticker, + }) }) }) if err != nil { @@ -95,7 +103,7 @@ func (g *RealGateway) GetInstrument(ctx context.Context, ticker, classCode strin if etf == nil { return domain.Instrument{}, ErrNotFound } - return domain.Instrument{ + instrument := domain.Instrument{ InstrumentUID: etf.GetUid(), Figi: etf.GetFigi(), Ticker: etf.GetTicker(), @@ -106,16 +114,25 @@ func (g *RealGateway) GetInstrument(ctx context.Context, ticker, classCode strin Currency: strings.ToUpper(etf.GetCurrency()), Enabled: etf.GetApiTradeAvailableFlag() && etf.GetBuyAvailableFlag() && etf.GetSellAvailableFlag(), UpdatedAt: time.Now().UTC(), - }, nil + } + g.storeInstrumentLot(instrument) + return instrument, nil } func (g *RealGateway) GetCandles(ctx context.Context, instrumentUID string, interval string, from, to time.Time) ([]domain.Candle, error) { if err := ctx.Err(); err != nil { return nil, err } - resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.GetCandlesResponse, error) { - return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetCandlesResponse, error) { - return g.marketData.GetCandles(instrumentUID, candleInterval(interval), from, to, pb.GetCandlesRequest_CANDLE_SOURCE_EXCHANGE, 0) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (*pb.GetCandlesResponse, error) { + return retryValue(callCtx, g.retryAttempts, g.retryBackoff, func() (*pb.GetCandlesResponse, error) { + source := pb.GetCandlesRequest_CANDLE_SOURCE_EXCHANGE + return g.marketDataPB.GetCandles(callCtx, &pb.GetCandlesRequest{ + From: investgo.TimeToTimestamp(from), + To: investgo.TimeToTimestamp(to), + Interval: candleInterval(interval), + InstrumentId: &instrumentUID, + CandleSourceType: &source, + }) }) }) if err != nil { @@ -143,9 +160,12 @@ func (g *RealGateway) GetOrderBook(ctx context.Context, instrumentUID string, de if err := ctx.Err(); err != nil { return domain.OrderBook{}, err } - resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.GetOrderBookResponse, error) { - return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrderBookResponse, error) { - return g.marketData.GetOrderBook(instrumentUID, depth) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (*pb.GetOrderBookResponse, error) { + return retryValue(callCtx, g.retryAttempts, g.retryBackoff, func() (*pb.GetOrderBookResponse, error) { + return g.marketDataPB.GetOrderBook(callCtx, &pb.GetOrderBookRequest{ + Depth: depth, + InstrumentId: &instrumentUID, + }) }) }) if err != nil { @@ -164,9 +184,11 @@ func (g *RealGateway) GetTradingStatus(ctx context.Context, instrumentUID string if err := ctx.Err(); err != nil { return domain.TradingStatusUnknown, err } - resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.GetTradingStatusResponse, error) { - return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetTradingStatusResponse, error) { - return g.marketData.GetTradingStatus(instrumentUID) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (*pb.GetTradingStatusResponse, error) { + return retryValue(callCtx, g.retryAttempts, g.retryBackoff, func() (*pb.GetTradingStatusResponse, error) { + return g.marketDataPB.GetTradingStatus(callCtx, &pb.GetTradingStatusRequest{ + InstrumentId: &instrumentUID, + }) }) }) if err != nil { @@ -192,9 +214,9 @@ func (g *RealGateway) PostLimitOrder(ctx context.Context, accountID, instrumentU if err != nil { return domain.Order{}, err } - resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.PostOrderResponse, error) { - return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PostOrderResponse, error) { - return g.orders.PostOrder(&investgo.PostOrderRequest{ + resp, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (*pb.PostOrderResponse, error) { + return retryValue(callCtx, g.retryAttempts, g.retryBackoff, func() (*pb.PostOrderResponse, error) { + return g.ordersPB.PostOrder(callCtx, &pb.PostOrderRequest{ InstrumentId: instrumentUID, Quantity: lots, Price: quotation, @@ -210,16 +232,19 @@ func (g *RealGateway) PostLimitOrder(ctx context.Context, accountID, instrumentU if err != nil { return domain.Order{}, err } - return orderFromPostResponse(resp.PostOrderResponse, accountID, clientOrderID, side, price), nil + return orderFromPostResponse(resp, accountID, clientOrderID, side, price), nil } func (g *RealGateway) CancelOrder(ctx context.Context, accountID, orderID string) error { if err := ctx.Err(); err != nil { return err } - _, err := requestWithTimeout(ctx, g.requestTimeout, func() (struct{}, error) { - return struct{}{}, withRetry(ctx, g.retryAttempts, g.retryBackoff, func() error { - _, err := g.orders.CancelOrder(accountID, orderID, nil) + _, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (struct{}, error) { + return struct{}{}, withRetry(callCtx, g.retryAttempts, g.retryBackoff, func() error { + _, err := g.ordersPB.CancelOrder(callCtx, &pb.CancelOrderRequest{ + AccountId: accountID, + OrderId: orderID, + }) return err }) }) @@ -230,24 +255,28 @@ func (g *RealGateway) GetOrderState(ctx context.Context, accountID, orderID stri if err := ctx.Err(); err != nil { return domain.Order{}, err } - resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.GetOrderStateResponse, error) { - return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrderStateResponse, error) { - return g.orders.GetOrderState(accountID, orderID, pb.PriceType_PRICE_TYPE_CURRENCY, nil) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (*pb.OrderState, error) { + return retryValue(callCtx, g.retryAttempts, g.retryBackoff, func() (*pb.OrderState, error) { + return g.ordersPB.GetOrderState(callCtx, &pb.GetOrderStateRequest{ + AccountId: accountID, + OrderId: orderID, + PriceType: pb.PriceType_PRICE_TYPE_CURRENCY, + }) }) }) if err != nil { return domain.Order{}, err } - return orderFromState(resp.OrderState, accountID), nil + return orderFromState(resp, accountID), nil } func (g *RealGateway) GetActiveOrders(ctx context.Context, accountID string) ([]domain.Order, error) { if err := ctx.Err(); err != nil { return nil, err } - resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.GetOrdersResponse, error) { - return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrdersResponse, error) { - return g.orders.GetOrders(accountID, nil) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (*pb.GetOrdersResponse, error) { + return retryValue(callCtx, g.retryAttempts, g.retryBackoff, func() (*pb.GetOrdersResponse, error) { + return g.ordersPB.GetOrders(callCtx, &pb.GetOrdersRequest{AccountId: accountID}) }) }) if err != nil { @@ -265,34 +294,38 @@ func (g *RealGateway) GetPortfolio(ctx context.Context, accountID string) (domai if err := ctx.Err(); err != nil { return domain.Portfolio{}, err } - resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.PortfolioResponse, error) { - return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PortfolioResponse, error) { - return g.operations.GetPortfolio(accountID, pb.PortfolioRequest_RUB) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (*pb.PortfolioResponse, error) { + return retryValue(callCtx, g.retryAttempts, g.retryBackoff, func() (*pb.PortfolioResponse, error) { + currency := pb.PortfolioRequest_RUB + return g.operationsPB.GetPortfolio(callCtx, &pb.PortfolioRequest{ + AccountId: accountID, + Currency: ¤cy, + }) }) }) if err != nil { return domain.Portfolio{}, err } - return portfolioFromResponse(resp.PortfolioResponse) + return portfolioFromResponse(resp, g.lotForInstrument) } func (g *RealGateway) GetOperations(ctx context.Context, accountID string, from, to time.Time) ([]domain.Operation, error) { if err := ctx.Err(); err != nil { return nil, err } - resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.OperationsResponse, error) { - return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.OperationsResponse, error) { - return g.operations.GetOperations(&investgo.GetOperationsRequest{ + resp, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (*pb.OperationsResponse, error) { + return retryValue(callCtx, g.retryAttempts, g.retryBackoff, func() (*pb.OperationsResponse, error) { + return g.operationsPB.GetOperations(callCtx, &pb.OperationsRequest{ AccountId: accountID, - From: from, - To: to, + From: investgo.TimeToTimestamp(from), + To: investgo.TimeToTimestamp(to), }) }) }) if err != nil { return nil, err } - return operationsFromResponse(resp.OperationsResponse), nil + return operationsFromResponse(resp), nil } func operationsFromResponse(resp *pb.OperationsResponse) []domain.Operation { @@ -312,13 +345,13 @@ func operationsFromResponse(resp *pb.OperationsResponse) []domain.Operation { return out } -func portfolioFromResponse(resp *pb.PortfolioResponse) (domain.Portfolio, error) { +func portfolioFromResponse(resp *pb.PortfolioResponse, lotForInstrument func(string) int64) (domain.Portfolio, error) { positions := resp.GetPositions() holdings := make([]domain.Holding, 0, len(positions)) for _, position := range positions { holdings = append(holdings, domain.Holding{ InstrumentUID: position.GetInstrumentUid(), - QuantityLots: portfolioQuantityLots(position), + QuantityLots: portfolioQuantityLots(position, portfolioPositionLot(position, lotForInstrument)), AveragePrice: money.MoneyValueToDecimal(position.GetAveragePositionPrice()), MarketValue: money.MoneyValueToDecimal(position.GetCurrentPrice()).Mul(money.QuotationToDecimal(position.GetQuantity())), }) @@ -343,15 +376,20 @@ func (g *RealGateway) GetServerTime(ctx context.Context) (time.Time, error) { if err := ctx.Err(); err != nil { return time.Time{}, err } - resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.GetInfoResponse, error) { - return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetInfoResponse, error) { - return g.users.GetInfo() + header, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (metadata.MD, error) { + return retryValue(callCtx, g.retryAttempts, g.retryBackoff, func() (metadata.MD, error) { + var header, trailer metadata.MD + _, err := g.usersPB.GetInfo(callCtx, &pb.GetInfoRequest{}, grpc.Header(&header), grpc.Trailer(&trailer)) + if err != nil { + return trailer, err + } + return header, nil }) }) if err != nil { return time.Time{}, err } - if serverTime, ok := serverTimeFromHeader(resp.Header); ok { + if serverTime, ok := serverTimeFromHeader(header); ok { return serverTime, nil } return time.Time{}, errors.New("server time is unavailable in response metadata") @@ -376,14 +414,47 @@ func rubMoneyValueToDecimal(value *pb.MoneyValue) (decimal.Decimal, error) { return money.MoneyValueToDecimal(value), nil } -func portfolioQuantityLots(position *pb.PortfolioPosition) int64 { +func portfolioPositionLot(position *pb.PortfolioPosition, lotForInstrument func(string) int64) int64 { + if position == nil || lotForInstrument == nil { + return 0 + } + return lotForInstrument(position.GetInstrumentUid()) +} + +func portfolioQuantityLots(position *pb.PortfolioPosition, lot int64) int64 { if position == nil { return 0 } if lots, ok := portfolioDeprecatedQuantityLots(position); ok { return lots.IntPart() } - return money.QuotationToDecimal(position.GetQuantity()).IntPart() + quantity := money.QuotationToDecimal(position.GetQuantity()) + if lot > 0 { + return quantity.Div(decimal.NewFromInt(lot)).IntPart() + } + return quantity.IntPart() +} + +func (g *RealGateway) storeInstrumentLot(instrument domain.Instrument) { + if instrument.InstrumentUID == "" || instrument.Lot <= 0 { + return + } + g.instrumentLots.Store(instrument.InstrumentUID, instrument.Lot) +} + +func (g *RealGateway) lotForInstrument(instrumentUID string) int64 { + if instrumentUID == "" { + return 0 + } + value, ok := g.instrumentLots.Load(instrumentUID) + if !ok { + return 0 + } + lot, ok := value.(int64) + if !ok { + return 0 + } + return lot } func portfolioDeprecatedQuantityLots(position *pb.PortfolioPosition) (decimal.Decimal, bool) { diff --git a/internal/tinvest/real_test.go b/internal/tinvest/real_test.go index 4dbfd89..887d2e5 100644 --- a/internal/tinvest/real_test.go +++ b/internal/tinvest/real_test.go @@ -37,3 +37,29 @@ func TestMarshalProtoRedactsAccountID(t *testing.T) { t.Fatalf("sanitizer removed non-sensitive data: %s", raw) } } + +func TestPortfolioFromResponseConvertsUnitsToLots(t *testing.T) { + portfolio, err := portfolioFromResponse(&pb.PortfolioResponse{ + Positions: []*pb.PortfolioPosition{ + { + InstrumentUid: "uid", + Quantity: &pb.Quotation{Units: 20}, + CurrentPrice: &pb.MoneyValue{Currency: "rub", Units: 10}, + }, + }, + }, func(instrumentUID string) int64 { + if instrumentUID == "uid" { + return 10 + } + return 0 + }) + if err != nil { + t.Fatal(err) + } + if got := portfolio.Holdings[0].QuantityLots; got != 2 { + t.Fatalf("quantity lots=%d, want 2", got) + } + if !portfolio.Holdings[0].MarketValue.Equal(decimal.NewFromInt(200)) { + t.Fatalf("market value=%s, want 200", portfolio.Holdings[0].MarketValue) + } +} diff --git a/internal/tinvest/retry.go b/internal/tinvest/retry.go index 5337966..d570f9f 100644 --- a/internal/tinvest/retry.go +++ b/internal/tinvest/retry.go @@ -63,26 +63,11 @@ func retryValue[T any](ctx context.Context, attempts int, interval time.Duration return out, nil } -func requestWithTimeout[T any](ctx context.Context, timeout time.Duration, fn func() (T, error)) (T, error) { +func requestWithTimeout[T any](ctx context.Context, timeout time.Duration, fn func(context.Context) (T, error)) (T, error) { if timeout <= 0 { - return fn() + return fn(ctx) } callCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - type result struct { - value T - err error - } - done := make(chan result, 1) - go func() { - value, err := fn() - done <- result{value: value, err: err} - }() - select { - case res := <-done: - return res.value, res.err - case <-callCtx.Done(): - var zero T - return zero, callCtx.Err() - } + return fn(callCtx) } diff --git a/internal/tinvest/retry_test.go b/internal/tinvest/retry_test.go index 621c5c5..fe75dbf 100644 --- a/internal/tinvest/retry_test.go +++ b/internal/tinvest/retry_test.go @@ -25,9 +25,9 @@ func TestWithRetryRetriesUntilSuccess(t *testing.T) { } func TestRequestWithTimeoutReturnsDeadline(t *testing.T) { - _, err := requestWithTimeout(context.Background(), time.Millisecond, func() (int, error) { - time.Sleep(50 * time.Millisecond) - return 1, nil + _, err := requestWithTimeout(context.Background(), time.Millisecond, func(ctx context.Context) (int, error) { + <-ctx.Done() + return 0, ctx.Err() }) if !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("err=%v, want DeadlineExceeded", err) diff --git a/internal/tinvest/sandbox.go b/internal/tinvest/sandbox.go index 86b9e92..4db9030 100644 --- a/internal/tinvest/sandbox.go +++ b/internal/tinvest/sandbox.go @@ -4,9 +4,9 @@ import ( "context" "time" - "github.com/russianinvestments/invest-api-go-sdk/investgo" pb "github.com/russianinvestments/invest-api-go-sdk/proto" "github.com/shopspring/decimal" + "google.golang.org/protobuf/types/known/timestamppb" "overnight-trading-bot/internal/domain" "overnight-trading-bot/internal/money" @@ -16,7 +16,7 @@ const sandboxEndpoint = "sandbox-invest-public-api.tinkoff.ru:443" type SandboxGateway struct { *RealGateway - sandbox *investgo.SandboxServiceClient + sandboxPB pb.SandboxServiceClient } func NewSandboxGateway(ctx context.Context, opts Options) (*SandboxGateway, error) { @@ -27,7 +27,7 @@ func NewSandboxGateway(ctx context.Context, opts Options) (*SandboxGateway, erro } return &SandboxGateway{ RealGateway: realGateway, - sandbox: realGateway.client.NewSandboxServiceClient(), + sandboxPB: pb.NewSandboxServiceClient(realGateway.client.Conn), }, nil } @@ -43,9 +43,9 @@ func (g *SandboxGateway) PostLimitOrder(ctx context.Context, accountID, instrume if err != nil { return domain.Order{}, err } - resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.PostOrderResponse, error) { - return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PostOrderResponse, error) { - return g.sandbox.PostSandboxOrder(&investgo.PostOrderRequest{ + resp, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (*pb.PostOrderResponse, error) { + return retryValue(callCtx, g.retryAttempts, g.retryBackoff, func() (*pb.PostOrderResponse, error) { + return g.sandboxPB.PostSandboxOrder(callCtx, &pb.PostOrderRequest{ InstrumentId: instrumentUID, Quantity: lots, Price: quotation, @@ -61,16 +61,19 @@ func (g *SandboxGateway) PostLimitOrder(ctx context.Context, accountID, instrume if err != nil { return domain.Order{}, err } - return orderFromPostResponse(resp.PostOrderResponse, accountID, clientOrderID, side, price), nil + return orderFromPostResponse(resp, accountID, clientOrderID, side, price), nil } func (g *SandboxGateway) CancelOrder(ctx context.Context, accountID, orderID string) error { if err := ctx.Err(); err != nil { return err } - _, err := requestWithTimeout(ctx, g.requestTimeout, func() (struct{}, error) { - return struct{}{}, withRetry(ctx, g.retryAttempts, g.retryBackoff, func() error { - _, err := g.sandbox.CancelSandboxOrder(accountID, orderID) + _, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (struct{}, error) { + return struct{}{}, withRetry(callCtx, g.retryAttempts, g.retryBackoff, func() error { + _, err := g.sandboxPB.CancelSandboxOrder(callCtx, &pb.CancelOrderRequest{ + AccountId: accountID, + OrderId: orderID, + }) return err }) }) @@ -81,24 +84,28 @@ func (g *SandboxGateway) GetOrderState(ctx context.Context, accountID, orderID s if err := ctx.Err(); err != nil { return domain.Order{}, err } - resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.GetOrderStateResponse, error) { - return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrderStateResponse, error) { - return g.sandbox.GetSandboxOrderState(accountID, orderID) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (*pb.OrderState, error) { + return retryValue(callCtx, g.retryAttempts, g.retryBackoff, func() (*pb.OrderState, error) { + return g.sandboxPB.GetSandboxOrderState(callCtx, &pb.GetOrderStateRequest{ + AccountId: accountID, + OrderId: orderID, + PriceType: pb.PriceType_PRICE_TYPE_CURRENCY, + }) }) }) if err != nil { return domain.Order{}, err } - return orderFromState(resp.OrderState, accountID), nil + return orderFromState(resp, accountID), nil } func (g *SandboxGateway) GetActiveOrders(ctx context.Context, accountID string) ([]domain.Order, error) { if err := ctx.Err(); err != nil { return nil, err } - resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.GetOrdersResponse, error) { - return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrdersResponse, error) { - return g.sandbox.GetSandboxOrders(accountID) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (*pb.GetOrdersResponse, error) { + return retryValue(callCtx, g.retryAttempts, g.retryBackoff, func() (*pb.GetOrdersResponse, error) { + return g.sandboxPB.GetSandboxOrders(callCtx, &pb.GetOrdersRequest{AccountId: accountID}) }) }) if err != nil { @@ -116,32 +123,36 @@ func (g *SandboxGateway) GetPortfolio(ctx context.Context, accountID string) (do if err := ctx.Err(); err != nil { return domain.Portfolio{}, err } - resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.PortfolioResponse, error) { - return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PortfolioResponse, error) { - return g.sandbox.GetSandboxPortfolio(accountID, pb.PortfolioRequest_RUB) + resp, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (*pb.PortfolioResponse, error) { + return retryValue(callCtx, g.retryAttempts, g.retryBackoff, func() (*pb.PortfolioResponse, error) { + currency := pb.PortfolioRequest_RUB + return g.sandboxPB.GetSandboxPortfolio(callCtx, &pb.PortfolioRequest{ + AccountId: accountID, + Currency: ¤cy, + }) }) }) if err != nil { return domain.Portfolio{}, err } - return portfolioFromResponse(resp.PortfolioResponse) + return portfolioFromResponse(resp, g.lotForInstrument) } func (g *SandboxGateway) GetOperations(ctx context.Context, accountID string, from, to time.Time) ([]domain.Operation, error) { if err := ctx.Err(); err != nil { return nil, err } - resp, err := requestWithTimeout(ctx, g.requestTimeout, func() (*investgo.OperationsResponse, error) { - return retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.OperationsResponse, error) { - return g.sandbox.GetSandboxOperations(&investgo.GetOperationsRequest{ + resp, err := requestWithTimeout(ctx, g.requestTimeout, func(callCtx context.Context) (*pb.OperationsResponse, error) { + return retryValue(callCtx, g.retryAttempts, g.retryBackoff, func() (*pb.OperationsResponse, error) { + return g.sandboxPB.GetSandboxOperations(callCtx, &pb.OperationsRequest{ AccountId: accountID, - From: from, - To: to, + From: timestamppb.New(from), + To: timestamppb.New(to), }) }) }) if err != nil { return nil, err } - return operationsFromResponse(resp.OperationsResponse), nil + return operationsFromResponse(resp), nil }