first version
This commit is contained in:
@@ -0,0 +1,61 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
migratemysql "github.com/golang-migrate/migrate/v4/database/mysql"
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
|
||||
"overnight-trading-bot/internal/repository/migrations"
|
||||
)
|
||||
|
||||
func ApplyMigrations(ctx context.Context, db *sql.DB) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
driver, err := migratemysql.WithInstance(db, &migratemysql.Config{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create mysql migration driver: %w", err)
|
||||
}
|
||||
source, err := iofs.New(migrations.FS, ".")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create iofs migration source: %w", err)
|
||||
}
|
||||
m, err := migrate.NewWithInstance("iofs", source, "mysql", driver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create migrate instance: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_, _ = m.Close()
|
||||
}()
|
||||
if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||
return fmt.Errorf("apply migrations: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RollbackAll(db *sql.DB) error {
|
||||
driver, err := migratemysql.WithInstance(db, &migratemysql.Config{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create mysql migration driver: %w", err)
|
||||
}
|
||||
source, err := iofs.New(migrations.FS, ".")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create iofs migration source: %w", err)
|
||||
}
|
||||
m, err := migrate.NewWithInstance("iofs", source, "mysql", driver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create migrate instance: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_, _ = m.Close()
|
||||
}()
|
||||
if err := m.Down(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||
return fmt.Errorf("rollback migrations: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,790 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/repository"
|
||||
)
|
||||
|
||||
var _ repository.Repository = (*Repository)(nil)
|
||||
|
||||
type Repository struct {
|
||||
db *sqlx.DB
|
||||
tx *sqlx.Tx
|
||||
}
|
||||
|
||||
func NewRepository(db *sqlx.DB) *Repository {
|
||||
return &Repository{db: db}
|
||||
}
|
||||
|
||||
func (r *Repository) RunInTx(ctx context.Context, fn func(ctx context.Context, repo repository.Repository) error) error {
|
||||
if r.tx != nil {
|
||||
return fn(ctx, r)
|
||||
}
|
||||
tx, err := r.db.BeginTxx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
txRepo := &Repository{db: r.db, tx: tx}
|
||||
if err := fn(ctx, txRepo); err != nil {
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
return fmt.Errorf("%w; rollback: %v", err, rbErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *Repository) execer() sqlx.ExtContext {
|
||||
if r.tx != nil {
|
||||
return r.tx
|
||||
}
|
||||
return r.db
|
||||
}
|
||||
|
||||
func (r *Repository) selectContext(ctx context.Context, dest any, query string, args ...any) error {
|
||||
if r.tx != nil {
|
||||
return r.tx.SelectContext(ctx, dest, query, args...)
|
||||
}
|
||||
return r.db.SelectContext(ctx, dest, query, args...)
|
||||
}
|
||||
|
||||
func (r *Repository) getContext(ctx context.Context, dest any, query string, args ...any) error {
|
||||
if r.tx != nil {
|
||||
return r.tx.GetContext(ctx, dest, query, args...)
|
||||
}
|
||||
return r.db.GetContext(ctx, dest, query, args...)
|
||||
}
|
||||
|
||||
func (r *Repository) UpsertInstrument(ctx context.Context, instrument domain.Instrument) error {
|
||||
if instrument.UpdatedAt.IsZero() {
|
||||
instrument.UpdatedAt = time.Now().UTC()
|
||||
}
|
||||
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
INSERT INTO instruments (
|
||||
instrument_uid, figi, ticker, class_code, name, lot, min_price_increment, currency,
|
||||
enabled, fund_type, expected_commission_bps_per_side, free_order_limit_per_day,
|
||||
quarantine, quarantine_reason, exclude_reason, updated_at
|
||||
) VALUES (
|
||||
:instrument_uid, :figi, :ticker, :class_code, :name, :lot, :min_price_increment, :currency,
|
||||
:enabled, :fund_type, :expected_commission_bps_per_side, :free_order_limit_per_day,
|
||||
:quarantine, :quarantine_reason, :exclude_reason, :updated_at
|
||||
) ON DUPLICATE KEY UPDATE
|
||||
instrument_uid=VALUES(instrument_uid),
|
||||
figi=VALUES(figi),
|
||||
name=VALUES(name),
|
||||
lot=VALUES(lot),
|
||||
min_price_increment=VALUES(min_price_increment),
|
||||
currency=VALUES(currency),
|
||||
enabled=VALUES(enabled),
|
||||
fund_type=VALUES(fund_type),
|
||||
expected_commission_bps_per_side=VALUES(expected_commission_bps_per_side),
|
||||
free_order_limit_per_day=VALUES(free_order_limit_per_day),
|
||||
quarantine=VALUES(quarantine),
|
||||
quarantine_reason=VALUES(quarantine_reason),
|
||||
exclude_reason=VALUES(exclude_reason),
|
||||
updated_at=VALUES(updated_at)`, instrumentRowFromDomain(instrument))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) ReplaceInstrument(ctx context.Context, oldInstrumentUID string, instrument domain.Instrument) error {
|
||||
if oldInstrumentUID == "" || oldInstrumentUID == instrument.InstrumentUID {
|
||||
return r.UpsertInstrument(ctx, instrument)
|
||||
}
|
||||
return r.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error {
|
||||
txRepo, ok := repo.(*Repository)
|
||||
if !ok {
|
||||
return errors.New("unexpected repository implementation")
|
||||
}
|
||||
return txRepo.replaceInstrument(ctx, oldInstrumentUID, instrument)
|
||||
})
|
||||
}
|
||||
|
||||
func (r *Repository) replaceInstrument(ctx context.Context, oldInstrumentUID string, instrument domain.Instrument) error {
|
||||
if instrument.UpdatedAt.IsZero() {
|
||||
instrument.UpdatedAt = time.Now().UTC()
|
||||
}
|
||||
exists, err := r.instrumentExists(ctx, instrument.InstrumentUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
if err := r.mergeInstrumentUID(ctx, oldInstrumentUID, instrument.InstrumentUID); err != nil {
|
||||
return err
|
||||
}
|
||||
return r.UpsertInstrument(ctx, instrument)
|
||||
}
|
||||
result, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
UPDATE instruments SET
|
||||
instrument_uid=:instrument_uid,
|
||||
figi=:figi,
|
||||
ticker=:ticker,
|
||||
class_code=:class_code,
|
||||
name=:name,
|
||||
lot=:lot,
|
||||
min_price_increment=:min_price_increment,
|
||||
currency=:currency,
|
||||
enabled=:enabled,
|
||||
fund_type=:fund_type,
|
||||
expected_commission_bps_per_side=:expected_commission_bps_per_side,
|
||||
free_order_limit_per_day=:free_order_limit_per_day,
|
||||
quarantine=:quarantine,
|
||||
quarantine_reason=:quarantine_reason,
|
||||
exclude_reason=:exclude_reason,
|
||||
updated_at=:updated_at
|
||||
WHERE instrument_uid=:old_instrument_uid`, replaceInstrumentRowFromDomain(oldInstrumentUID, instrument))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return r.UpsertInstrument(ctx, instrument)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Repository) instrumentExists(ctx context.Context, instrumentUID string) (bool, error) {
|
||||
var count int
|
||||
if err := r.getContext(ctx, &count, `SELECT COUNT(*) FROM instruments WHERE instrument_uid=?`, instrumentUID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *Repository) mergeInstrumentUID(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
|
||||
if oldInstrumentUID == newInstrumentUID {
|
||||
return nil
|
||||
}
|
||||
if err := r.mergeDailyCandles(ctx, oldInstrumentUID, newInstrumentUID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.mergeMinuteCandles(ctx, oldInstrumentUID, newInstrumentUID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.mergeFeatures(ctx, oldInstrumentUID, newInstrumentUID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.mergeSignals(ctx, oldInstrumentUID, newInstrumentUID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.mergeFreeOrders(ctx, oldInstrumentUID, newInstrumentUID); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, table := range []string{"orders", "positions", "risk_events"} {
|
||||
if _, err := r.execer().ExecContext(ctx, fmt.Sprintf(`UPDATE %s SET instrument_uid=? WHERE instrument_uid=?`, table), newInstrumentUID, oldInstrumentUID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := r.execer().ExecContext(ctx, `DELETE FROM instruments WHERE instrument_uid=?`, oldInstrumentUID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) mergeDailyCandles(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO candles_daily (instrument_uid, trade_date, open, high, low, close, volume_lots, source, loaded_at)
|
||||
SELECT ?, trade_date, open, high, low, close, volume_lots, source, loaded_at
|
||||
FROM candles_daily WHERE instrument_uid=?
|
||||
ON DUPLICATE KEY UPDATE
|
||||
open=VALUES(open), high=VALUES(high), low=VALUES(low), close=VALUES(close),
|
||||
volume_lots=VALUES(volume_lots), source=VALUES(source), loaded_at=VALUES(loaded_at)`, newInstrumentUID, oldInstrumentUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.execer().ExecContext(ctx, `DELETE FROM candles_daily WHERE instrument_uid=?`, oldInstrumentUID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) mergeMinuteCandles(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO candles_minute (instrument_uid, ts, open, high, low, close, volume_lots, source, loaded_at)
|
||||
SELECT ?, ts, open, high, low, close, volume_lots, source, loaded_at
|
||||
FROM candles_minute WHERE instrument_uid=?
|
||||
ON DUPLICATE KEY UPDATE
|
||||
open=VALUES(open), high=VALUES(high), low=VALUES(low), close=VALUES(close),
|
||||
volume_lots=VALUES(volume_lots), source=VALUES(source), loaded_at=VALUES(loaded_at)`, newInstrumentUID, oldInstrumentUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.execer().ExecContext(ctx, `DELETE FROM candles_minute WHERE instrument_uid=?`, oldInstrumentUID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) mergeFeatures(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO features (
|
||||
instrument_uid, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60,
|
||||
tstat_on_60, win_on_60, ewma_on, spread_bps, half_spread_bps, tick_bps,
|
||||
adv_20, expected_cost_bps, net_edge_bps, entry_interval_volume,
|
||||
exit_interval_volume, calculated_at
|
||||
)
|
||||
SELECT
|
||||
?, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60,
|
||||
tstat_on_60, win_on_60, ewma_on, spread_bps, half_spread_bps, tick_bps,
|
||||
adv_20, expected_cost_bps, net_edge_bps, entry_interval_volume,
|
||||
exit_interval_volume, calculated_at
|
||||
FROM features WHERE instrument_uid=?
|
||||
ON DUPLICATE KEY UPDATE
|
||||
r_on=VALUES(r_on), r_day=VALUES(r_day), mu_on_60=VALUES(mu_on_60),
|
||||
mu_on_252=VALUES(mu_on_252), sigma_on_60=VALUES(sigma_on_60),
|
||||
tstat_on_60=VALUES(tstat_on_60), win_on_60=VALUES(win_on_60),
|
||||
ewma_on=VALUES(ewma_on), spread_bps=VALUES(spread_bps),
|
||||
half_spread_bps=VALUES(half_spread_bps), tick_bps=VALUES(tick_bps),
|
||||
adv_20=VALUES(adv_20), expected_cost_bps=VALUES(expected_cost_bps),
|
||||
net_edge_bps=VALUES(net_edge_bps), entry_interval_volume=VALUES(entry_interval_volume),
|
||||
exit_interval_volume=VALUES(exit_interval_volume), calculated_at=VALUES(calculated_at)`, newInstrumentUID, oldInstrumentUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.execer().ExecContext(ctx, `DELETE FROM features WHERE instrument_uid=?`, oldInstrumentUID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) mergeSignals(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO signals (
|
||||
trade_date, instrument_uid, decision, score, net_edge_bps, target_notional,
|
||||
target_lots, reject_reason, context_json, created_at
|
||||
)
|
||||
SELECT trade_date, ?, decision, score, net_edge_bps, target_notional,
|
||||
target_lots, reject_reason, context_json, created_at
|
||||
FROM signals WHERE instrument_uid=?
|
||||
ON DUPLICATE KEY UPDATE
|
||||
decision=VALUES(decision), score=VALUES(score), net_edge_bps=VALUES(net_edge_bps),
|
||||
target_notional=VALUES(target_notional), target_lots=VALUES(target_lots),
|
||||
reject_reason=VALUES(reject_reason), context_json=VALUES(context_json),
|
||||
created_at=VALUES(created_at)`, newInstrumentUID, oldInstrumentUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.execer().ExecContext(ctx, `DELETE FROM signals WHERE instrument_uid=?`, oldInstrumentUID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) mergeFreeOrders(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO free_order_counters (trade_date, instrument_uid, orders_sent)
|
||||
SELECT trade_date, ?, orders_sent FROM free_order_counters WHERE instrument_uid=?
|
||||
ON DUPLICATE KEY UPDATE orders_sent=GREATEST(orders_sent, VALUES(orders_sent))`, newInstrumentUID, oldInstrumentUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.execer().ExecContext(ctx, `DELETE FROM free_order_counters WHERE instrument_uid=?`, oldInstrumentUID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) ListInstruments(ctx context.Context, includeDisabled bool) ([]domain.Instrument, error) {
|
||||
query := `SELECT * FROM instruments`
|
||||
if !includeDisabled {
|
||||
query += ` WHERE enabled=1`
|
||||
}
|
||||
query += ` ORDER BY ticker`
|
||||
var rows []instrumentRow
|
||||
if err := r.selectContext(ctx, &rows, query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]domain.Instrument, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, row.domain())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *Repository) QuarantineInstrument(ctx context.Context, instrumentUID, reason string) error {
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
UPDATE instruments SET quarantine=1, quarantine_reason=?, updated_at=UTC_TIMESTAMP(3)
|
||||
WHERE instrument_uid=?`, reason, instrumentUID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) UpsertDailyCandles(ctx context.Context, candles []domain.Candle) error {
|
||||
for _, candle := range candles {
|
||||
if candle.LoadedAt.IsZero() {
|
||||
candle.LoadedAt = time.Now().UTC()
|
||||
}
|
||||
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
INSERT INTO candles_daily (
|
||||
instrument_uid, trade_date, open, high, low, close, volume_lots, source, loaded_at
|
||||
) VALUES (
|
||||
:instrument_uid, :trade_date, :open, :high, :low, :close, :volume_lots, :source, :loaded_at
|
||||
) ON DUPLICATE KEY UPDATE
|
||||
open=VALUES(open), high=VALUES(high), low=VALUES(low), close=VALUES(close),
|
||||
volume_lots=VALUES(volume_lots), source=VALUES(source), loaded_at=VALUES(loaded_at)`, candleRowFromDomain(candle))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Repository) ListDailyCandles(ctx context.Context, instrumentUID string, from, to time.Time) ([]domain.Candle, error) {
|
||||
var rows []candleRow
|
||||
if err := r.selectContext(ctx, &rows, `
|
||||
SELECT * FROM candles_daily
|
||||
WHERE instrument_uid=? AND trade_date BETWEEN ? AND ?
|
||||
ORDER BY trade_date`, instrumentUID, dateOnly(from), dateOnly(to)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]domain.Candle, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, row.domain())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *Repository) UpsertMinuteCandles(ctx context.Context, candles []domain.Candle) error {
|
||||
for _, candle := range candles {
|
||||
if candle.LoadedAt.IsZero() {
|
||||
candle.LoadedAt = time.Now().UTC()
|
||||
}
|
||||
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
INSERT INTO candles_minute (
|
||||
instrument_uid, ts, open, high, low, close, volume_lots, source, loaded_at
|
||||
) VALUES (
|
||||
:instrument_uid, :trade_date, :open, :high, :low, :close, :volume_lots, :source, :loaded_at
|
||||
) ON DUPLICATE KEY UPDATE
|
||||
open=VALUES(open), high=VALUES(high), low=VALUES(low), close=VALUES(close),
|
||||
volume_lots=VALUES(volume_lots), source=VALUES(source), loaded_at=VALUES(loaded_at)`, candleRowFromDomain(candle))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Repository) ListMinuteCandles(ctx context.Context, instrumentUID string, from, to time.Time) ([]domain.Candle, error) {
|
||||
var rows []candleRow
|
||||
if err := r.selectContext(ctx, &rows, `
|
||||
SELECT instrument_uid, ts AS trade_date, open, high, low, close, volume_lots, source, loaded_at
|
||||
FROM candles_minute
|
||||
WHERE instrument_uid=? AND ts BETWEEN ? AND ?
|
||||
ORDER BY ts`, instrumentUID, from.UTC(), to.UTC()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]domain.Candle, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, row.domain())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *Repository) UpsertFeature(ctx context.Context, feature domain.FeatureSet) error {
|
||||
if feature.CalculatedAt.IsZero() {
|
||||
feature.CalculatedAt = time.Now().UTC()
|
||||
}
|
||||
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
INSERT INTO features (
|
||||
instrument_uid, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60,
|
||||
tstat_on_60, win_on_60, ewma_on, spread_bps, half_spread_bps, tick_bps,
|
||||
adv_20, expected_cost_bps, net_edge_bps, entry_interval_volume,
|
||||
exit_interval_volume, calculated_at
|
||||
) VALUES (
|
||||
:instrument_uid, :trade_date, :r_on, :r_day, :mu_on_60, :mu_on_252, :sigma_on_60,
|
||||
:tstat_on_60, :win_on_60, :ewma_on, :spread_bps, :half_spread_bps, :tick_bps,
|
||||
:adv_20, :expected_cost_bps, :net_edge_bps, :entry_interval_volume,
|
||||
:exit_interval_volume, :calculated_at
|
||||
) ON DUPLICATE KEY UPDATE
|
||||
r_on=VALUES(r_on), r_day=VALUES(r_day), mu_on_60=VALUES(mu_on_60),
|
||||
mu_on_252=VALUES(mu_on_252), sigma_on_60=VALUES(sigma_on_60),
|
||||
tstat_on_60=VALUES(tstat_on_60), win_on_60=VALUES(win_on_60),
|
||||
ewma_on=VALUES(ewma_on), spread_bps=VALUES(spread_bps),
|
||||
half_spread_bps=VALUES(half_spread_bps), tick_bps=VALUES(tick_bps),
|
||||
adv_20=VALUES(adv_20), expected_cost_bps=VALUES(expected_cost_bps),
|
||||
net_edge_bps=VALUES(net_edge_bps), entry_interval_volume=VALUES(entry_interval_volume),
|
||||
exit_interval_volume=VALUES(exit_interval_volume), calculated_at=VALUES(calculated_at)`, featureRowFromDomain(feature))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) GetFeature(ctx context.Context, instrumentUID string, tradeDate time.Time) (domain.FeatureSet, error) {
|
||||
var row featureRow
|
||||
if err := r.getContext(ctx, &row, `SELECT * FROM features WHERE instrument_uid=? AND trade_date=?`, instrumentUID, dateOnly(tradeDate)); err != nil {
|
||||
return domain.FeatureSet{}, err
|
||||
}
|
||||
return row.domain(), nil
|
||||
}
|
||||
|
||||
func (r *Repository) UpsertSignal(ctx context.Context, signal domain.Signal) error {
|
||||
if signal.CreatedAt.IsZero() {
|
||||
signal.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
if signal.ContextJSON == "" {
|
||||
signal.ContextJSON = "{}"
|
||||
}
|
||||
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
INSERT INTO signals (
|
||||
trade_date, instrument_uid, decision, score, net_edge_bps, target_notional,
|
||||
target_lots, reject_reason, context_json, created_at
|
||||
) VALUES (
|
||||
:trade_date, :instrument_uid, :decision, :score, :net_edge_bps, :target_notional,
|
||||
:target_lots, :reject_reason, :context_json, :created_at
|
||||
) ON DUPLICATE KEY UPDATE
|
||||
decision=VALUES(decision), score=VALUES(score), net_edge_bps=VALUES(net_edge_bps),
|
||||
target_notional=VALUES(target_notional), target_lots=VALUES(target_lots),
|
||||
reject_reason=VALUES(reject_reason), context_json=VALUES(context_json),
|
||||
created_at=VALUES(created_at)`, signalRowFromDomain(signal))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) ListSignals(ctx context.Context, tradeDate time.Time) ([]domain.Signal, error) {
|
||||
var rows []signalRow
|
||||
if err := r.selectContext(ctx, &rows, `SELECT * FROM signals WHERE trade_date=? ORDER BY id`, dateOnly(tradeDate)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]domain.Signal, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, row.domain())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *Repository) UpsertOrder(ctx context.Context, order domain.Order) error {
|
||||
now := time.Now().UTC()
|
||||
if order.CreatedAt.IsZero() {
|
||||
order.CreatedAt = now
|
||||
}
|
||||
order.UpdatedAt = now
|
||||
if order.RawStateJSON == "" {
|
||||
order.RawStateJSON = "{}"
|
||||
}
|
||||
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
INSERT INTO orders (
|
||||
client_order_id, broker_order_id, account_id_hash, instrument_uid, trade_date,
|
||||
side, order_type, limit_price, quantity_lots, filled_lots, avg_fill_price,
|
||||
status, commission, attempt_no, raw_state_json, created_at, updated_at
|
||||
) VALUES (
|
||||
:client_order_id, :broker_order_id, :account_id_hash, :instrument_uid, :trade_date,
|
||||
:side, :order_type, :limit_price, :quantity_lots, :filled_lots, :avg_fill_price,
|
||||
:status, :commission, :attempt_no, :raw_state_json, :created_at, :updated_at
|
||||
) ON DUPLICATE KEY UPDATE
|
||||
broker_order_id=VALUES(broker_order_id), filled_lots=VALUES(filled_lots),
|
||||
avg_fill_price=VALUES(avg_fill_price), status=VALUES(status),
|
||||
commission=VALUES(commission), raw_state_json=VALUES(raw_state_json),
|
||||
updated_at=VALUES(updated_at)`, orderRowFromDomain(order))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) UpdateOrderStatus(ctx context.Context, clientOrderID string, status domain.OrderStatus, filledLots int64, rawJSON string) error {
|
||||
if rawJSON == "" {
|
||||
rawJSON = "{}"
|
||||
}
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
UPDATE orders SET status=?, filled_lots=?, raw_state_json=?, updated_at=UTC_TIMESTAMP(3)
|
||||
WHERE client_order_id=?`, status, filledLots, rawJSON, clientOrderID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) ListActiveOrders(ctx context.Context, accountIDHash string) ([]domain.Order, error) {
|
||||
var rows []orderRow
|
||||
if err := r.selectContext(ctx, &rows, `
|
||||
SELECT * FROM orders
|
||||
WHERE account_id_hash=? AND status IN ('NEW','SENT','PARTIALLY_FILLED')
|
||||
ORDER BY created_at`, accountIDHash); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]domain.Order, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, row.domain())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *Repository) ListOrders(ctx context.Context, accountIDHash string, from, to time.Time) ([]domain.Order, error) {
|
||||
var rows []orderRow
|
||||
if err := r.selectContext(ctx, &rows, `
|
||||
SELECT * FROM orders
|
||||
WHERE account_id_hash=? AND trade_date BETWEEN ? AND ?
|
||||
ORDER BY created_at`, accountIDHash, dateOnly(from), dateOnly(to)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]domain.Order, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, row.domain())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *Repository) UpsertPosition(ctx context.Context, position domain.Position) error {
|
||||
if position.UpdatedAt.IsZero() {
|
||||
position.UpdatedAt = time.Now().UTC()
|
||||
}
|
||||
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
INSERT INTO positions (
|
||||
id, account_id_hash, instrument_uid, open_trade_date, lots, lot_size, exit_filled_lots,
|
||||
avg_buy_price, avg_sell_price, status, gross_pnl, net_pnl, commission_total,
|
||||
realized_edge_bps, opened_at, closed_at, updated_at
|
||||
) VALUES (
|
||||
NULLIF(:id, 0), :account_id_hash, :instrument_uid, :open_trade_date, :lots, :lot_size, :exit_filled_lots,
|
||||
:avg_buy_price, :avg_sell_price, :status, :gross_pnl, :net_pnl, :commission_total,
|
||||
:realized_edge_bps, :opened_at, :closed_at, :updated_at
|
||||
) ON DUPLICATE KEY UPDATE
|
||||
lots=VALUES(lots), lot_size=VALUES(lot_size), exit_filled_lots=VALUES(exit_filled_lots), avg_buy_price=VALUES(avg_buy_price), avg_sell_price=VALUES(avg_sell_price),
|
||||
status=VALUES(status), gross_pnl=VALUES(gross_pnl), net_pnl=VALUES(net_pnl),
|
||||
commission_total=VALUES(commission_total), realized_edge_bps=VALUES(realized_edge_bps),
|
||||
opened_at=VALUES(opened_at), closed_at=VALUES(closed_at), updated_at=VALUES(updated_at)`, positionRowFromDomain(position))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) ListOpenPositions(ctx context.Context, accountIDHash string) ([]domain.Position, error) {
|
||||
var rows []positionRow
|
||||
if err := r.selectContext(ctx, &rows, `
|
||||
SELECT * FROM positions
|
||||
WHERE account_id_hash=? AND status NOT IN ('NO_POSITION','EXIT_FILLED','QUARANTINE')
|
||||
ORDER BY updated_at`, accountIDHash); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]domain.Position, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, row.domain())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *Repository) ListPositions(ctx context.Context, accountIDHash string, from, to time.Time) ([]domain.Position, error) {
|
||||
var rows []positionRow
|
||||
if err := r.selectContext(ctx, &rows, `
|
||||
SELECT * FROM positions
|
||||
WHERE account_id_hash=? AND open_trade_date BETWEEN ? AND ?
|
||||
ORDER BY updated_at`, accountIDHash, dateOnly(from), dateOnly(to)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]domain.Position, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, row.domain())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *Repository) InsertRiskEvent(ctx context.Context, event domain.RiskEvent) error {
|
||||
if event.TS.IsZero() {
|
||||
event.TS = time.Now().UTC()
|
||||
}
|
||||
if event.ContextJSON == "" {
|
||||
event.ContextJSON = "{}"
|
||||
}
|
||||
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
|
||||
INSERT INTO risk_events (ts, severity, event_type, instrument_uid, message, raw_context_json)
|
||||
VALUES (:ts, :severity, :event_type, :instrument_uid, :message, :raw_context_json)`, riskEventRowFromDomain(event))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) GetFreeOrdersSent(ctx context.Context, tradeDate time.Time, instrumentUID string) (int, error) {
|
||||
var sent int
|
||||
err := r.getContext(ctx, &sent, `
|
||||
SELECT orders_sent FROM free_order_counters WHERE trade_date=? AND instrument_uid=?`, dateOnly(tradeDate), instrumentUID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, nil
|
||||
}
|
||||
return sent, err
|
||||
}
|
||||
|
||||
func (r *Repository) IncrementFreeOrders(ctx context.Context, tradeDate time.Time, instrumentUID string, delta int) error {
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO free_order_counters (trade_date, instrument_uid, orders_sent)
|
||||
VALUES (?, ?, ?)
|
||||
ON DUPLICATE KEY UPDATE orders_sent=orders_sent+VALUES(orders_sent)`, dateOnly(tradeDate), instrumentUID, delta)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) GetSystemState(ctx context.Context) (domain.SystemState, bool, string, error) {
|
||||
var row struct {
|
||||
State string `db:"state"`
|
||||
Halted bool `db:"halted"`
|
||||
HaltReason sql.NullString `db:"halt_reason"`
|
||||
}
|
||||
if err := r.getContext(ctx, &row, `SELECT state, halted, halt_reason FROM system_state WHERE id=1`); err != nil {
|
||||
return "", false, "", err
|
||||
}
|
||||
return domain.SystemState(row.State), row.Halted, row.HaltReason.String, nil
|
||||
}
|
||||
|
||||
func (r *Repository) SaveSystemState(ctx context.Context, state domain.SystemState, mode domain.Mode, halted bool, reason string, contextJSON string) error {
|
||||
if contextJSON == "" {
|
||||
contextJSON = "{}"
|
||||
}
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO system_state (id, state, mode, halted, halt_reason, last_heartbeat, context_json)
|
||||
VALUES (1, ?, ?, ?, ?, UTC_TIMESTAMP(3), ?)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
state=VALUES(state), mode=VALUES(mode), halted=VALUES(halted),
|
||||
halt_reason=VALUES(halt_reason), last_heartbeat=VALUES(last_heartbeat),
|
||||
context_json=VALUES(context_json)`, state, mode, halted, nullableString(reason), contextJSON)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) Unhalt(ctx context.Context, reason string) error {
|
||||
return r.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error {
|
||||
state, halted, haltReason, err := repo.GetSystemState(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !halted && state != domain.StateHalted {
|
||||
return fmt.Errorf("system is not halted")
|
||||
}
|
||||
if err := repo.InsertRiskEvent(ctx, domain.RiskEvent{
|
||||
TS: time.Now().UTC(),
|
||||
Severity: domain.SeverityInfo,
|
||||
EventType: "manual_unhalt",
|
||||
Message: fmt.Sprintf("%s (previous halt: %s)", reason, haltReason),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
mode := domain.ModePaper
|
||||
if txRepo, ok := repo.(*Repository); ok {
|
||||
currentMode, err := txRepo.getSystemMode(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mode = currentMode
|
||||
}
|
||||
return repo.SaveSystemState(ctx, domain.StateInit, mode, false, "", `{"manual_unhalt":true}`)
|
||||
})
|
||||
}
|
||||
|
||||
func (r *Repository) getSystemMode(ctx context.Context) (domain.Mode, error) {
|
||||
var raw string
|
||||
if err := r.getContext(ctx, &raw, `SELECT mode FROM system_state WHERE id=1`); err != nil {
|
||||
return "", err
|
||||
}
|
||||
mode, err := domain.ParseMode(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return mode, nil
|
||||
}
|
||||
|
||||
func (r *Repository) WasDailyReportSent(ctx context.Context, reportDate time.Time, accountIDHash string) (bool, error) {
|
||||
var count int
|
||||
if err := r.getContext(ctx, &count, `
|
||||
SELECT COUNT(*) FROM daily_reports WHERE report_date=? AND account_id_hash=?`, dateOnly(reportDate), accountIDHash); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *Repository) MarkDailyReportSent(ctx context.Context, reportDate time.Time, accountIDHash string) error {
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO daily_reports (report_date, account_id_hash, sent_at)
|
||||
VALUES (?, ?, UTC_TIMESTAMP(3))
|
||||
ON DUPLICATE KEY UPDATE sent_at=sent_at`, dateOnly(reportDate), accountIDHash)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) InsertReconciliation(ctx context.Context, ts time.Time, diffJSON string, hasDiff bool) error {
|
||||
if ts.IsZero() {
|
||||
ts = time.Now().UTC()
|
||||
}
|
||||
if diffJSON == "" {
|
||||
diffJSON = "[]"
|
||||
}
|
||||
_, err := r.execer().ExecContext(ctx, `
|
||||
INSERT INTO reconciliations (ts, has_diff, diff_json)
|
||||
VALUES (?, ?, ?)`, ts, hasDiff, diffJSON)
|
||||
return err
|
||||
}
|
||||
|
||||
func dateOnly(t time.Time) time.Time {
|
||||
y, m, d := t.UTC().Date()
|
||||
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
func nullableString(s string) any {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
type instrumentRow struct {
|
||||
InstrumentUID string `db:"instrument_uid"`
|
||||
Figi sql.NullString `db:"figi"`
|
||||
Ticker string `db:"ticker"`
|
||||
ClassCode string `db:"class_code"`
|
||||
Name string `db:"name"`
|
||||
Lot int64 `db:"lot"`
|
||||
MinPriceIncrement decimal.Decimal `db:"min_price_increment"`
|
||||
Currency string `db:"currency"`
|
||||
Enabled bool `db:"enabled"`
|
||||
FundType string `db:"fund_type"`
|
||||
ExpectedCommissionBpsPerSide decimal.Decimal `db:"expected_commission_bps_per_side"`
|
||||
FreeOrderLimitPerDay int `db:"free_order_limit_per_day"`
|
||||
Quarantine bool `db:"quarantine"`
|
||||
QuarantineReason sql.NullString `db:"quarantine_reason"`
|
||||
ExcludeReason sql.NullString `db:"exclude_reason"`
|
||||
UpdatedAt time.Time `db:"updated_at"`
|
||||
}
|
||||
|
||||
func instrumentRowFromDomain(instrument domain.Instrument) instrumentRow {
|
||||
return instrumentRow{
|
||||
InstrumentUID: instrument.InstrumentUID,
|
||||
Figi: sql.NullString{String: instrument.Figi, Valid: instrument.Figi != ""},
|
||||
Ticker: instrument.Ticker,
|
||||
ClassCode: instrument.ClassCode,
|
||||
Name: instrument.Name,
|
||||
Lot: instrument.Lot,
|
||||
MinPriceIncrement: instrument.MinPriceIncrement,
|
||||
Currency: instrument.Currency,
|
||||
Enabled: instrument.Enabled,
|
||||
FundType: instrument.FundType,
|
||||
ExpectedCommissionBpsPerSide: instrument.ExpectedCommissionBpsPerSide,
|
||||
FreeOrderLimitPerDay: instrument.FreeOrderLimitPerDay,
|
||||
Quarantine: instrument.Quarantine,
|
||||
QuarantineReason: sql.NullString{String: instrument.QuarantineReason, Valid: instrument.QuarantineReason != ""},
|
||||
ExcludeReason: sql.NullString{String: instrument.ExcludeReason, Valid: instrument.ExcludeReason != ""},
|
||||
UpdatedAt: instrument.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func replaceInstrumentRowFromDomain(oldInstrumentUID string, instrument domain.Instrument) map[string]any {
|
||||
row := instrumentRowFromDomain(instrument)
|
||||
return map[string]any{
|
||||
"instrument_uid": row.InstrumentUID,
|
||||
"figi": row.Figi,
|
||||
"ticker": row.Ticker,
|
||||
"class_code": row.ClassCode,
|
||||
"name": row.Name,
|
||||
"lot": row.Lot,
|
||||
"min_price_increment": row.MinPriceIncrement,
|
||||
"currency": row.Currency,
|
||||
"enabled": row.Enabled,
|
||||
"fund_type": row.FundType,
|
||||
"expected_commission_bps_per_side": row.ExpectedCommissionBpsPerSide,
|
||||
"free_order_limit_per_day": row.FreeOrderLimitPerDay,
|
||||
"quarantine": row.Quarantine,
|
||||
"quarantine_reason": row.QuarantineReason,
|
||||
"exclude_reason": row.ExcludeReason,
|
||||
"updated_at": row.UpdatedAt,
|
||||
"old_instrument_uid": oldInstrumentUID,
|
||||
}
|
||||
}
|
||||
|
||||
func (r instrumentRow) domain() domain.Instrument {
|
||||
return domain.Instrument{
|
||||
InstrumentUID: r.InstrumentUID,
|
||||
Figi: r.Figi.String,
|
||||
Ticker: r.Ticker,
|
||||
ClassCode: r.ClassCode,
|
||||
Name: r.Name,
|
||||
Lot: r.Lot,
|
||||
MinPriceIncrement: r.MinPriceIncrement,
|
||||
Currency: r.Currency,
|
||||
Enabled: r.Enabled,
|
||||
FundType: r.FundType,
|
||||
ExpectedCommissionBpsPerSide: r.ExpectedCommissionBpsPerSide,
|
||||
FreeOrderLimitPerDay: r.FreeOrderLimitPerDay,
|
||||
Quarantine: r.Quarantine,
|
||||
QuarantineReason: r.QuarantineReason.String,
|
||||
ExcludeReason: r.ExcludeReason.String,
|
||||
UpdatedAt: r.UpdatedAt,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
//go:build integration
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mariadb"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
func TestRepositoryMariaDBMigrationsAndRoundTrip(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
container, err := mariadb.Run(ctx,
|
||||
"mariadb:11.4",
|
||||
mariadb.WithDatabase("overnight_bot"),
|
||||
mariadb.WithUsername("bot"),
|
||||
mariadb.WithPassword("bot"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := testcontainers.TerminateContainer(container); err != nil {
|
||||
t.Logf("terminate mariadb: %v", err)
|
||||
}
|
||||
})
|
||||
dsn, err := container.ConnectionString(ctx, "parseTime=true", "loc=UTC", "multiStatements=true")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
db, err := sqlx.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = db.Close()
|
||||
})
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := ApplyMigrations(ctx, db.DB); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
repo := NewRepository(db)
|
||||
instrument := domain.Instrument{
|
||||
InstrumentUID: "uid-trur",
|
||||
Ticker: "TRUR",
|
||||
ClassCode: "TQTF",
|
||||
Name: "TRUR",
|
||||
Lot: 1,
|
||||
MinPriceIncrement: decimal.NewFromFloat(0.0001),
|
||||
Currency: "RUB",
|
||||
Enabled: true,
|
||||
}
|
||||
if err := repo.ReplaceInstrument(ctx, "PENDING:TRUR", instrument); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tradeDate := time.Date(2026, 6, 7, 0, 0, 0, 0, time.UTC)
|
||||
position := domain.Position{
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid-trur",
|
||||
OpenTradeDate: tradeDate,
|
||||
Lots: 10,
|
||||
AvgBuyPrice: decimal.NewFromInt(100),
|
||||
Status: domain.PositionHoldingOvernight,
|
||||
}
|
||||
if err := repo.UpsertPosition(ctx, position); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
position.Lots = 8
|
||||
position.ExitFilledLots = 2
|
||||
if err := repo.UpsertPosition(ctx, position); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var count int
|
||||
if err := db.GetContext(ctx, &count, `
|
||||
SELECT COUNT(*) FROM positions WHERE account_id_hash='hash' AND instrument_uid='uid-trur' AND open_trade_date=?`, tradeDate); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("positions count=%d, want 1", count)
|
||||
}
|
||||
if err := repo.MarkDailyReportSent(ctx, tradeDate, "hash"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sent, err := repo.WasDailyReportSent(ctx, tradeDate, "hash")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !sent {
|
||||
t.Fatalf("daily report marker was not persisted")
|
||||
}
|
||||
if err := repo.UpsertOrder(ctx, domain.Order{
|
||||
ClientOrderID: "bad",
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "missing",
|
||||
TradeDate: tradeDate,
|
||||
Side: domain.SideBuy,
|
||||
OrderType: domain.OrderTypeLimit,
|
||||
LimitPrice: decimal.NewFromInt(100),
|
||||
QuantityLots: 1,
|
||||
Status: domain.OrderStatusSent,
|
||||
RawStateJSON: "{}",
|
||||
}); err == nil {
|
||||
t.Fatalf("expected FK failure for missing instrument")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,338 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
type candleRow struct {
|
||||
InstrumentUID string `db:"instrument_uid"`
|
||||
TradeDate time.Time `db:"trade_date"`
|
||||
Open decimal.Decimal `db:"open"`
|
||||
High decimal.Decimal `db:"high"`
|
||||
Low decimal.Decimal `db:"low"`
|
||||
Close decimal.Decimal `db:"close"`
|
||||
VolumeLots decimal.Decimal `db:"volume_lots"`
|
||||
Source string `db:"source"`
|
||||
LoadedAt time.Time `db:"loaded_at"`
|
||||
}
|
||||
|
||||
func candleRowFromDomain(candle domain.Candle) candleRow {
|
||||
return candleRow{
|
||||
InstrumentUID: candle.InstrumentUID,
|
||||
TradeDate: dateOnly(candle.TradeDate),
|
||||
Open: candle.Open,
|
||||
High: candle.High,
|
||||
Low: candle.Low,
|
||||
Close: candle.Close,
|
||||
VolumeLots: candle.VolumeLots,
|
||||
Source: candle.Source,
|
||||
LoadedAt: candle.LoadedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (r candleRow) domain() domain.Candle {
|
||||
return domain.Candle{
|
||||
InstrumentUID: r.InstrumentUID,
|
||||
TradeDate: r.TradeDate,
|
||||
Open: r.Open,
|
||||
High: r.High,
|
||||
Low: r.Low,
|
||||
Close: r.Close,
|
||||
VolumeLots: r.VolumeLots,
|
||||
Source: r.Source,
|
||||
LoadedAt: r.LoadedAt,
|
||||
}
|
||||
}
|
||||
|
||||
type featureRow struct {
|
||||
InstrumentUID string `db:"instrument_uid"`
|
||||
TradeDate time.Time `db:"trade_date"`
|
||||
ROn decimal.Decimal `db:"r_on"`
|
||||
RDay decimal.Decimal `db:"r_day"`
|
||||
MuOn60 decimal.Decimal `db:"mu_on_60"`
|
||||
MuOn252 decimal.Decimal `db:"mu_on_252"`
|
||||
SigmaOn60 decimal.Decimal `db:"sigma_on_60"`
|
||||
TStatOn60 decimal.Decimal `db:"tstat_on_60"`
|
||||
WinOn60 decimal.Decimal `db:"win_on_60"`
|
||||
EWMAOn decimal.Decimal `db:"ewma_on"`
|
||||
SpreadBps decimal.Decimal `db:"spread_bps"`
|
||||
HalfSpreadBps decimal.Decimal `db:"half_spread_bps"`
|
||||
TickBps decimal.Decimal `db:"tick_bps"`
|
||||
ADV20 decimal.Decimal `db:"adv_20"`
|
||||
ExpectedCostBps decimal.Decimal `db:"expected_cost_bps"`
|
||||
NetEdgeBps decimal.Decimal `db:"net_edge_bps"`
|
||||
EntryIntervalVolume decimal.Decimal `db:"entry_interval_volume"`
|
||||
ExitIntervalVolume decimal.Decimal `db:"exit_interval_volume"`
|
||||
CalculatedAt time.Time `db:"calculated_at"`
|
||||
}
|
||||
|
||||
func featureRowFromDomain(feature domain.FeatureSet) featureRow {
|
||||
return featureRow{
|
||||
InstrumentUID: feature.InstrumentUID,
|
||||
TradeDate: dateOnly(feature.TradeDate),
|
||||
ROn: feature.ROn,
|
||||
RDay: feature.RDay,
|
||||
MuOn60: feature.MuOn60,
|
||||
MuOn252: feature.MuOn252,
|
||||
SigmaOn60: feature.SigmaOn60,
|
||||
TStatOn60: feature.TStatOn60,
|
||||
WinOn60: feature.WinOn60,
|
||||
EWMAOn: feature.EWMAOn,
|
||||
SpreadBps: feature.SpreadBps,
|
||||
HalfSpreadBps: feature.HalfSpreadBps,
|
||||
TickBps: feature.TickBps,
|
||||
ADV20: feature.ADV20,
|
||||
ExpectedCostBps: feature.ExpectedCostBps,
|
||||
NetEdgeBps: feature.NetEdgeBps,
|
||||
EntryIntervalVolume: feature.EntryIntervalVolume,
|
||||
ExitIntervalVolume: feature.ExitIntervalVolume,
|
||||
CalculatedAt: feature.CalculatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (r featureRow) domain() domain.FeatureSet {
|
||||
return domain.FeatureSet{
|
||||
InstrumentUID: r.InstrumentUID,
|
||||
TradeDate: r.TradeDate,
|
||||
ROn: r.ROn,
|
||||
RDay: r.RDay,
|
||||
MuOn60: r.MuOn60,
|
||||
MuOn252: r.MuOn252,
|
||||
SigmaOn60: r.SigmaOn60,
|
||||
TStatOn60: r.TStatOn60,
|
||||
WinOn60: r.WinOn60,
|
||||
EWMAOn: r.EWMAOn,
|
||||
SpreadBps: r.SpreadBps,
|
||||
HalfSpreadBps: r.HalfSpreadBps,
|
||||
TickBps: r.TickBps,
|
||||
ADV20: r.ADV20,
|
||||
ExpectedCostBps: r.ExpectedCostBps,
|
||||
NetEdgeBps: r.NetEdgeBps,
|
||||
EntryIntervalVolume: r.EntryIntervalVolume,
|
||||
ExitIntervalVolume: r.ExitIntervalVolume,
|
||||
CalculatedAt: r.CalculatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
type signalRow struct {
|
||||
ID int64 `db:"id"`
|
||||
TradeDate time.Time `db:"trade_date"`
|
||||
InstrumentUID string `db:"instrument_uid"`
|
||||
Decision string `db:"decision"`
|
||||
Score decimal.Decimal `db:"score"`
|
||||
NetEdgeBps decimal.Decimal `db:"net_edge_bps"`
|
||||
TargetNotional decimal.Decimal `db:"target_notional"`
|
||||
TargetLots int64 `db:"target_lots"`
|
||||
RejectReason sql.NullString `db:"reject_reason"`
|
||||
ContextJSON sql.NullString `db:"context_json"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
}
|
||||
|
||||
func signalRowFromDomain(signal domain.Signal) signalRow {
|
||||
return signalRow{
|
||||
ID: signal.ID,
|
||||
TradeDate: dateOnly(signal.TradeDate),
|
||||
InstrumentUID: signal.InstrumentUID,
|
||||
Decision: string(signal.Decision),
|
||||
Score: signal.Score,
|
||||
NetEdgeBps: signal.NetEdgeBps,
|
||||
TargetNotional: signal.TargetNotional,
|
||||
TargetLots: signal.TargetLots,
|
||||
RejectReason: sql.NullString{String: signal.RejectReason, Valid: signal.RejectReason != ""},
|
||||
ContextJSON: sql.NullString{String: signal.ContextJSON, Valid: signal.ContextJSON != ""},
|
||||
CreatedAt: signal.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (r signalRow) domain() domain.Signal {
|
||||
return domain.Signal{
|
||||
ID: r.ID,
|
||||
TradeDate: r.TradeDate,
|
||||
InstrumentUID: r.InstrumentUID,
|
||||
Decision: domain.SignalDecision(r.Decision),
|
||||
Score: r.Score,
|
||||
NetEdgeBps: r.NetEdgeBps,
|
||||
TargetNotional: r.TargetNotional,
|
||||
TargetLots: r.TargetLots,
|
||||
RejectReason: r.RejectReason.String,
|
||||
ContextJSON: r.ContextJSON.String,
|
||||
CreatedAt: r.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
type orderRow struct {
|
||||
ClientOrderID string `db:"client_order_id"`
|
||||
BrokerOrderID sql.NullString `db:"broker_order_id"`
|
||||
AccountIDHash string `db:"account_id_hash"`
|
||||
InstrumentUID string `db:"instrument_uid"`
|
||||
TradeDate time.Time `db:"trade_date"`
|
||||
Side string `db:"side"`
|
||||
OrderType string `db:"order_type"`
|
||||
LimitPrice decimal.Decimal `db:"limit_price"`
|
||||
QuantityLots int64 `db:"quantity_lots"`
|
||||
FilledLots int64 `db:"filled_lots"`
|
||||
AvgFillPrice decimal.Decimal `db:"avg_fill_price"`
|
||||
Status string `db:"status"`
|
||||
Commission decimal.Decimal `db:"commission"`
|
||||
AttemptNo int `db:"attempt_no"`
|
||||
RawStateJSON sql.NullString `db:"raw_state_json"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at"`
|
||||
}
|
||||
|
||||
func orderRowFromDomain(order domain.Order) orderRow {
|
||||
return orderRow{
|
||||
ClientOrderID: order.ClientOrderID,
|
||||
BrokerOrderID: sql.NullString{
|
||||
String: order.BrokerOrderID,
|
||||
Valid: order.BrokerOrderID != "",
|
||||
},
|
||||
AccountIDHash: order.AccountIDHash,
|
||||
InstrumentUID: order.InstrumentUID,
|
||||
TradeDate: dateOnly(order.TradeDate),
|
||||
Side: string(order.Side),
|
||||
OrderType: string(order.OrderType),
|
||||
LimitPrice: order.LimitPrice,
|
||||
QuantityLots: order.QuantityLots,
|
||||
FilledLots: order.FilledLots,
|
||||
AvgFillPrice: order.AvgFillPrice,
|
||||
Status: string(order.Status),
|
||||
Commission: order.Commission,
|
||||
AttemptNo: order.AttemptNo,
|
||||
RawStateJSON: sql.NullString{
|
||||
String: order.RawStateJSON,
|
||||
Valid: order.RawStateJSON != "",
|
||||
},
|
||||
CreatedAt: order.CreatedAt,
|
||||
UpdatedAt: order.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (r orderRow) domain() domain.Order {
|
||||
return domain.Order{
|
||||
ClientOrderID: r.ClientOrderID,
|
||||
BrokerOrderID: r.BrokerOrderID.String,
|
||||
AccountIDHash: r.AccountIDHash,
|
||||
InstrumentUID: r.InstrumentUID,
|
||||
TradeDate: r.TradeDate,
|
||||
Side: domain.Side(r.Side),
|
||||
OrderType: domain.OrderType(r.OrderType),
|
||||
LimitPrice: r.LimitPrice,
|
||||
QuantityLots: r.QuantityLots,
|
||||
FilledLots: r.FilledLots,
|
||||
AvgFillPrice: r.AvgFillPrice,
|
||||
Status: domain.OrderStatus(r.Status),
|
||||
Commission: r.Commission,
|
||||
AttemptNo: r.AttemptNo,
|
||||
RawStateJSON: r.RawStateJSON.String,
|
||||
CreatedAt: r.CreatedAt,
|
||||
UpdatedAt: r.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
type positionRow struct {
|
||||
ID int64 `db:"id"`
|
||||
AccountIDHash string `db:"account_id_hash"`
|
||||
InstrumentUID string `db:"instrument_uid"`
|
||||
OpenTradeDate time.Time `db:"open_trade_date"`
|
||||
Lots int64 `db:"lots"`
|
||||
Lot int64 `db:"lot_size"`
|
||||
ExitFilledLots int64 `db:"exit_filled_lots"`
|
||||
AvgBuyPrice decimal.Decimal `db:"avg_buy_price"`
|
||||
AvgSellPrice decimal.Decimal `db:"avg_sell_price"`
|
||||
Status string `db:"status"`
|
||||
GrossPnL decimal.Decimal `db:"gross_pnl"`
|
||||
NetPnL decimal.Decimal `db:"net_pnl"`
|
||||
CommissionTotal decimal.Decimal `db:"commission_total"`
|
||||
RealizedEdgeBps decimal.Decimal `db:"realized_edge_bps"`
|
||||
OpenedAt sql.NullTime `db:"opened_at"`
|
||||
ClosedAt sql.NullTime `db:"closed_at"`
|
||||
UpdatedAt time.Time `db:"updated_at"`
|
||||
}
|
||||
|
||||
func positionRowFromDomain(position domain.Position) positionRow {
|
||||
lot := position.Lot
|
||||
if lot <= 0 {
|
||||
lot = 1
|
||||
}
|
||||
return positionRow{
|
||||
ID: position.ID,
|
||||
AccountIDHash: position.AccountIDHash,
|
||||
InstrumentUID: position.InstrumentUID,
|
||||
OpenTradeDate: dateOnly(position.OpenTradeDate),
|
||||
Lots: position.Lots,
|
||||
Lot: lot,
|
||||
ExitFilledLots: position.ExitFilledLots,
|
||||
AvgBuyPrice: position.AvgBuyPrice,
|
||||
AvgSellPrice: position.AvgSellPrice,
|
||||
Status: string(position.Status),
|
||||
GrossPnL: position.GrossPnL,
|
||||
NetPnL: position.NetPnL,
|
||||
CommissionTotal: position.CommissionTotal,
|
||||
RealizedEdgeBps: position.RealizedEdgeBps,
|
||||
OpenedAt: nullableTime(position.OpenedAt),
|
||||
ClosedAt: nullableTime(position.ClosedAt),
|
||||
UpdatedAt: position.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (r positionRow) domain() domain.Position {
|
||||
return domain.Position{
|
||||
ID: r.ID,
|
||||
AccountIDHash: r.AccountIDHash,
|
||||
InstrumentUID: r.InstrumentUID,
|
||||
OpenTradeDate: r.OpenTradeDate,
|
||||
Lots: r.Lots,
|
||||
Lot: r.Lot,
|
||||
ExitFilledLots: r.ExitFilledLots,
|
||||
AvgBuyPrice: r.AvgBuyPrice,
|
||||
AvgSellPrice: r.AvgSellPrice,
|
||||
Status: domain.PositionStatus(r.Status),
|
||||
GrossPnL: r.GrossPnL,
|
||||
NetPnL: r.NetPnL,
|
||||
CommissionTotal: r.CommissionTotal,
|
||||
RealizedEdgeBps: r.RealizedEdgeBps,
|
||||
OpenedAt: timePtr(r.OpenedAt),
|
||||
ClosedAt: timePtr(r.ClosedAt),
|
||||
UpdatedAt: r.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
type riskEventRow struct {
|
||||
TS time.Time `db:"ts"`
|
||||
Severity string `db:"severity"`
|
||||
EventType string `db:"event_type"`
|
||||
InstrumentUID sql.NullString `db:"instrument_uid"`
|
||||
Message string `db:"message"`
|
||||
ContextJSON string `db:"raw_context_json"`
|
||||
}
|
||||
|
||||
func riskEventRowFromDomain(event domain.RiskEvent) riskEventRow {
|
||||
return riskEventRow{
|
||||
TS: event.TS,
|
||||
Severity: string(event.Severity),
|
||||
EventType: event.EventType,
|
||||
InstrumentUID: sql.NullString{String: event.InstrumentUID, Valid: event.InstrumentUID != ""},
|
||||
Message: event.Message,
|
||||
ContextJSON: event.ContextJSON,
|
||||
}
|
||||
}
|
||||
|
||||
func nullableTime(t *time.Time) sql.NullTime {
|
||||
if t == nil {
|
||||
return sql.NullTime{}
|
||||
}
|
||||
return sql.NullTime{Time: *t, Valid: true}
|
||||
}
|
||||
|
||||
func timePtr(t sql.NullTime) *time.Time {
|
||||
if !t.Valid {
|
||||
return nil
|
||||
}
|
||||
return &t.Time
|
||||
}
|
||||
Reference in New Issue
Block a user