diff --git a/internal/bot/telegram/telegram.go b/internal/bot/telegram/telegram.go index 8157b2d..f3972d9 100644 --- a/internal/bot/telegram/telegram.go +++ b/internal/bot/telegram/telegram.go @@ -14,14 +14,19 @@ import ( "github.com/shopspring/decimal" ) -const alertsPageSize = 5 +const ( + alertsPageSize = 5 + instrPageSize = 4 +) // Usecase defines the business logic operations required by the bot. type Usecase interface { RegisterNewUser(ctx context.Context, user *entities.User) error UserByID(ctx context.Context, userID entities.UserID) (*entities.User, error) UserByTgID(ctx context.Context, telegramID entities.TelegramID) (*entities.User, error) - InstrumentList(ctx context.Context, offset, limit int) ([]entities.Instrument, error) + InstrumentList(ctx context.Context, userID entities.UserID, offset, limit int) ([]entities.Instrument, error) + AddUserInstrument(ctx context.Context, userID entities.UserID, base, quote string) (*entities.Instrument, error) + RemoveUserInstrument(ctx context.Context, userID entities.UserID, instrumentID entities.InstrumentID) error CreateAlert(ctx context.Context, alert *entities.Alert) (entities.AlertID, error) Alert(ctx context.Context, alertID entities.AlertID) (*entities.Alert, error) Alerts(ctx context.Context, userID entities.UserID, offset, limit int) ([]entities.Alert, error) @@ -35,9 +40,10 @@ type Alerter interface { RemoveAlert(id entities.AlertID) } -// PriceProvider fetches the current market price for an instrument. +// PriceProvider fetches the current market price for an instrument and validates pairs. type PriceProvider interface { Price(ctx context.Context, instrument entities.Instrument) (*entities.Price, error) + InstrumentExists(ctx context.Context, base, quote string) (bool, error) } // Reply keyboard button labels. @@ -45,6 +51,7 @@ const ( btnAddAlert = "Add Alert" btnMyAlerts = "My Alerts" btnInstruments = "Instruments" + btnAddPair = "Add Pair" ) type flowStep string @@ -53,6 +60,7 @@ const ( stepAddAlertPrice flowStep = "add_alert_price" stepAddAlertAwaitType flowStep = "add_alert_await_type" // price entered, waiting for type callback stepEditAlertPrice flowStep = "edit_alert_price" + stepAddPair flowStep = "add_pair" ) type userState struct { @@ -213,6 +221,9 @@ func (b *Bot) handleMessage(ctx context.Context, msg *tgbotapi.Message) { case btnInstruments: b.cmdInstruments(ctx, tgID, chatID) return + case btnAddPair: + b.cmdAddPair(ctx, tgID, chatID) + return } // Route plain-text input to the active multi-step flow. @@ -224,6 +235,8 @@ func (b *Bot) handleMessage(ctx context.Context, msg *tgbotapi.Message) { b.send(chatID, "Please select the alert type using the buttons above.") case stepEditAlertPrice: b.handleEditAlertPrice(ctx, tgID, chatID, msg.Text, state) + case stepAddPair: + b.handleAddPairInput(ctx, tgID, chatID, msg.Text) default: b.sendMenu(chatID, "Use the menu below or /add_alert to set a price alert.") } @@ -239,10 +252,12 @@ func (b *Bot) handleCommand(ctx context.Context, tgID entities.TelegramID, chatI b.cmdAddAlert(ctx, tgID, chatID) case "my_alerts": b.cmdMyAlerts(ctx, tgID, chatID) + case "add_pair": + b.cmdAddPair(ctx, tgID, chatID) case "cancel": b.cmdCancel(tgID, chatID) default: - b.sendMenu(chatID, "Unknown command.\n\nAvailable commands:\n/start — register\n/instruments — list trading pairs\n/add_alert — create a price alert\n/my_alerts — view your alerts\n/cancel — cancel current operation") + b.sendMenu(chatID, "Unknown command.\n\nAvailable commands:\n/start — register\n/instruments — list trading pairs\n/add_alert — create a price alert\n/my_alerts — view your alerts\n/add_pair — add a custom trading pair\n/cancel — cancel current operation") } } @@ -260,6 +275,46 @@ func (b *Bot) handleCallback(ctx context.Context, cb *tgbotapi.CallbackQuery) { instrID := entities.InstrumentID(strings.TrimPrefix(data, "instrument:")) b.handleInstrumentSelected(ctx, tgID, chatID, instrID) + case strings.HasPrefix(data, "add_alert_instr_page:"): + page, _ := strconv.Atoi(strings.TrimPrefix(data, "add_alert_instr_page:")) + b.handleAddAlertInstrPage(ctx, tgID, chatID, messageID, page) + + case strings.HasPrefix(data, "instr_page:"): + page, _ := strconv.Atoi(strings.TrimPrefix(data, "instr_page:")) + b.handleInstrumentsPage(ctx, tgID, chatID, messageID, page) + + case strings.HasPrefix(data, "instr_select:"): + // format: instr_select:: + rest := strings.TrimPrefix(data, "instr_select:") + idx := strings.LastIndex(rest, ":") + instrID := entities.InstrumentID(rest[:idx]) + page, _ := strconv.Atoi(rest[idx+1:]) + b.handleInstrumentSelect(ctx, tgID, chatID, messageID, instrID, page) + + case strings.HasPrefix(data, "instr_remove:"): + // format: instr_remove:: + rest := strings.TrimPrefix(data, "instr_remove:") + idx := strings.LastIndex(rest, ":") + instrID := entities.InstrumentID(rest[:idx]) + page, _ := strconv.Atoi(rest[idx+1:]) + b.handleInstrumentRemoveConfirm(chatID, messageID, instrID, page) + + case strings.HasPrefix(data, "confirm_remove_instr:"): + // format: confirm_remove_instr:: + rest := strings.TrimPrefix(data, "confirm_remove_instr:") + idx := strings.LastIndex(rest, ":") + instrID := entities.InstrumentID(rest[:idx]) + page, _ := strconv.Atoi(rest[idx+1:]) + b.handleInstrumentRemoveDo(ctx, tgID, chatID, messageID, instrID, page) + + case strings.HasPrefix(data, "cancel_remove_instr:"): + // format: cancel_remove_instr:: + rest := strings.TrimPrefix(data, "cancel_remove_instr:") + idx := strings.LastIndex(rest, ":") + instrID := entities.InstrumentID(rest[:idx]) + page, _ := strconv.Atoi(rest[idx+1:]) + b.handleInstrumentSelect(ctx, tgID, chatID, messageID, instrID, page) + case strings.HasPrefix(data, "alerts_page:"), strings.HasPrefix(data, "alerts_back:"): rest := data[strings.Index(data, ":")+1:] page, _ := strconv.Atoi(rest) @@ -331,8 +386,13 @@ func (b *Bot) cmdStart(ctx context.Context, tgID entities.TelegramID, chatID int b.sendMenu(chatID, "Welcome! You are now registered.\n\nUse the menu below to get started.") } -func (b *Bot) cmdInstruments(ctx context.Context, _ entities.TelegramID, chatID int64) { - instruments, err := b.usecase.InstrumentList(ctx, 0, 50) +func (b *Bot) cmdInstruments(ctx context.Context, tgID entities.TelegramID, chatID int64) { + user, err := b.requireUser(ctx, tgID, chatID) + if err != nil { + return + } + + instruments, err := b.usecase.InstrumentList(ctx, user.ID, 0, 200) if err != nil { b.log.Error("failed to list instruments", "err", err) b.sendMenu(chatID, "Failed to load instruments.") @@ -343,41 +403,32 @@ func (b *Bot) cmdInstruments(ctx context.Context, _ entities.TelegramID, chatID return } - var sb strings.Builder - sb.WriteString("Available trading pairs:\n\n") - for _, instr := range instruments { - fmt.Fprintf(&sb, "- %s/%s\n", instr.BaseCurrency, instr.QuoteCurrency) - } - b.sendMenu(chatID, sb.String()) + text, kb := buildInstrumentsPage(instruments, 0) + msg := tgbotapi.NewMessage(chatID, text) + msg.ReplyMarkup = kb + b.sendMsg(msg) } func (b *Bot) cmdAddAlert(ctx context.Context, tgID entities.TelegramID, chatID int64) { - if _, err := b.requireUser(ctx, tgID, chatID); err != nil { + user, err := b.requireUser(ctx, tgID, chatID) + if err != nil { return } - instruments, err := b.usecase.InstrumentList(ctx, 0, 50) + instruments, err := b.usecase.InstrumentList(ctx, user.ID, 0, 200) if err != nil { b.log.Error("failed to list instruments", "err", err) b.sendMenu(chatID, "Failed to load instruments.") return } if len(instruments) == 0 { - b.sendMenu(chatID, "No instruments available.") + b.sendMenu(chatID, "No instruments available. Use \"Add Pair\" to add one.") return } - var rows [][]tgbotapi.InlineKeyboardButton - for _, instr := range instruments { - btn := tgbotapi.NewInlineKeyboardButtonData( - fmt.Sprintf("%s/%s", instr.BaseCurrency, instr.QuoteCurrency), - fmt.Sprintf("instrument:%s", instr.ID), - ) - rows = append(rows, tgbotapi.NewInlineKeyboardRow(btn)) - } - - msg := tgbotapi.NewMessage(chatID, "Select a trading pair:") - msg.ReplyMarkup = tgbotapi.NewInlineKeyboardMarkup(rows...) + text, kb := buildAddAlertInstrPage(instruments, 0) + msg := tgbotapi.NewMessage(chatID, text) + msg.ReplyMarkup = kb b.sendMsg(msg) } @@ -404,6 +455,236 @@ func (b *Bot) cmdMyAlerts(ctx context.Context, tgID entities.TelegramID, chatID b.sendMsg(msg) } +func (b *Bot) cmdAddPair(ctx context.Context, tgID entities.TelegramID, chatID int64) { + if _, err := b.requireUser(ctx, tgID, chatID); err != nil { + return + } + b.setState(tgID, &userState{step: stepAddPair}) + b.send(chatID, "Enter the trading pair in BASE/QUOTE format (e.g. DOGE/USDT):") +} + +func (b *Bot) cmdCancel(tgID entities.TelegramID, chatID int64) { + b.setState(tgID, &userState{}) + b.sendMenu(chatID, "Operation cancelled.") +} + +// --- Instruments pagination --- + +// buildInstrumentsPage constructs a paginated instruments list. +// Numbered buttons let users open a detail view; ◀ ▶ navigate pages. +func buildInstrumentsPage(instruments []entities.Instrument, page int) (string, tgbotapi.InlineKeyboardMarkup) { + total := len(instruments) + totalPages := (total + instrPageSize - 1) / instrPageSize + + start := page * instrPageSize + end := start + instrPageSize + if end > total { + end = total + } + pageItems := instruments[start:end] + + var sb strings.Builder + fmt.Fprintf(&sb, "Trading pairs (page %d/%d):\n\n", page+1, totalPages) + for i, instr := range pageItems { + marker := "" + if !instr.IsGlobal { + marker = " ✦" // marks user-added pairs + } + fmt.Fprintf(&sb, "%d. %s/%s%s\n", start+i+1, instr.BaseCurrency, instr.QuoteCurrency, marker) + } + + var rows [][]tgbotapi.InlineKeyboardButton + + // One button per instrument on this page. + var itemRow []tgbotapi.InlineKeyboardButton + for i, instr := range pageItems { + itemRow = append(itemRow, tgbotapi.NewInlineKeyboardButtonData( + strconv.Itoa(start+i+1), + fmt.Sprintf("instr_select:%s:%d", instr.ID, page), + )) + } + if len(itemRow) > 0 { + rows = append(rows, itemRow) + } + + var navRow []tgbotapi.InlineKeyboardButton + if page > 0 { + navRow = append(navRow, tgbotapi.NewInlineKeyboardButtonData("◀", fmt.Sprintf("instr_page:%d", page-1))) + } + if end < total { + navRow = append(navRow, tgbotapi.NewInlineKeyboardButtonData("▶", fmt.Sprintf("instr_page:%d", page+1))) + } + if len(navRow) > 0 { + rows = append(rows, navRow) + } + + return sb.String(), tgbotapi.NewInlineKeyboardMarkup(rows...) +} + +// handleInstrumentsPage re-fetches instruments and edits the message to show the requested page. +func (b *Bot) handleInstrumentsPage(ctx context.Context, tgID entities.TelegramID, chatID int64, messageID int, page int) { + user, err := b.requireUser(ctx, tgID, chatID) + if err != nil { + return + } + + instruments, err := b.usecase.InstrumentList(ctx, user.ID, 0, 200) + if err != nil { + b.log.Error("failed to list instruments", "err", err) + return + } + + totalPages := (len(instruments) + instrPageSize - 1) / instrPageSize + if page >= totalPages { + page = totalPages - 1 + } + + text, kb := buildInstrumentsPage(instruments, page) + edit := tgbotapi.NewEditMessageTextAndMarkup(chatID, messageID, text, kb) + if _, err := b.api.Send(edit); err != nil { + b.log.Error("failed to edit message", "err", err) + } +} + +// handleInstrumentSelect shows instrument detail with a Remove button (for user-added pairs). +func (b *Bot) handleInstrumentSelect(ctx context.Context, tgID entities.TelegramID, chatID int64, messageID int, instrID entities.InstrumentID, page int) { + user, err := b.requireUser(ctx, tgID, chatID) + if err != nil { + return + } + + instruments, err := b.usecase.InstrumentList(ctx, user.ID, 0, 200) + if err != nil { + b.log.Error("failed to list instruments", "err", err) + return + } + + var selected entities.Instrument + for _, instr := range instruments { + if instr.ID == instrID { + selected = instr + break + } + } + if selected.ID == "" { + b.editMsgText(chatID, messageID, "Instrument not found.") + return + } + + text := fmt.Sprintf("%s/%s", selected.BaseCurrency, selected.QuoteCurrency) + if !selected.IsGlobal { + text += "\n(user-added pair)" + } + + var btns []tgbotapi.InlineKeyboardButton + if !selected.IsGlobal { + btns = append(btns, tgbotapi.NewInlineKeyboardButtonData( + "Remove", fmt.Sprintf("instr_remove:%s:%d", instrID, page), + )) + } + btns = append(btns, tgbotapi.NewInlineKeyboardButtonData( + "◀ Back", fmt.Sprintf("instr_page:%d", page), + )) + + edit := tgbotapi.NewEditMessageTextAndMarkup(chatID, messageID, text, + tgbotapi.NewInlineKeyboardMarkup(btns)) + if _, err := b.api.Send(edit); err != nil { + b.log.Error("failed to edit message", "err", err) + } +} + +// handleInstrumentRemoveConfirm asks the user to confirm removal. +func (b *Bot) handleInstrumentRemoveConfirm(chatID int64, messageID int, instrID entities.InstrumentID, page int) { + kb := tgbotapi.NewInlineKeyboardMarkup(tgbotapi.NewInlineKeyboardRow( + tgbotapi.NewInlineKeyboardButtonData("Yes, remove", fmt.Sprintf("confirm_remove_instr:%s:%d", instrID, page)), + tgbotapi.NewInlineKeyboardButtonData("Cancel", fmt.Sprintf("cancel_remove_instr:%s:%d", instrID, page)), + )) + edit := tgbotapi.NewEditMessageTextAndMarkup(chatID, messageID, "Remove this pair from your list?", kb) + if _, err := b.api.Send(edit); err != nil { + b.log.Error("failed to edit message", "err", err) + } +} + +// handleInstrumentRemoveDo removes the instrument from the user's list and refreshes the page. +func (b *Bot) handleInstrumentRemoveDo(ctx context.Context, tgID entities.TelegramID, chatID int64, messageID int, instrID entities.InstrumentID, page int) { + user, err := b.requireUser(ctx, tgID, chatID) + if err != nil { + return + } + + if err := b.usecase.RemoveUserInstrument(ctx, user.ID, instrID); err != nil { + b.log.Error("failed to remove user instrument", "instrument_id", instrID, "err", err) + b.editMsgText(chatID, messageID, "Failed to remove pair.") + return + } + + b.handleInstrumentsPage(ctx, tgID, chatID, messageID, page) +} + +// buildAddAlertInstrPage builds the instrument selection keyboard for the add-alert flow. +func buildAddAlertInstrPage(instruments []entities.Instrument, page int) (string, tgbotapi.InlineKeyboardMarkup) { + total := len(instruments) + totalPages := (total + instrPageSize - 1) / instrPageSize + + start := page * instrPageSize + end := start + instrPageSize + if end > total { + end = total + } + pageItems := instruments[start:end] + + text := fmt.Sprintf("Select a trading pair (page %d/%d):", page+1, totalPages) + + var rows [][]tgbotapi.InlineKeyboardButton + for _, instr := range pageItems { + btn := tgbotapi.NewInlineKeyboardButtonData( + fmt.Sprintf("%s/%s", instr.BaseCurrency, instr.QuoteCurrency), + fmt.Sprintf("instrument:%s", instr.ID), + ) + rows = append(rows, tgbotapi.NewInlineKeyboardRow(btn)) + } + + var navRow []tgbotapi.InlineKeyboardButton + if page > 0 { + navRow = append(navRow, tgbotapi.NewInlineKeyboardButtonData("◀", fmt.Sprintf("add_alert_instr_page:%d", page-1))) + } + if end < total { + navRow = append(navRow, tgbotapi.NewInlineKeyboardButtonData("▶", fmt.Sprintf("add_alert_instr_page:%d", page+1))) + } + if len(navRow) > 0 { + rows = append(rows, navRow) + } + + return text, tgbotapi.NewInlineKeyboardMarkup(rows...) +} + +// handleAddAlertInstrPage edits the instrument selection message to show a different page. +func (b *Bot) handleAddAlertInstrPage(ctx context.Context, tgID entities.TelegramID, chatID int64, messageID int, page int) { + user, err := b.requireUser(ctx, tgID, chatID) + if err != nil { + return + } + + instruments, err := b.usecase.InstrumentList(ctx, user.ID, 0, 200) + if err != nil { + b.log.Error("failed to list instruments", "err", err) + return + } + + totalPages := (len(instruments) + instrPageSize - 1) / instrPageSize + if page >= totalPages { + page = totalPages - 1 + } + + text, kb := buildAddAlertInstrPage(instruments, page) + edit := tgbotapi.NewEditMessageTextAndMarkup(chatID, messageID, text, kb) + if _, err := b.api.Send(edit); err != nil { + b.log.Error("failed to edit message", "err", err) + } +} + +// --- Alerts pagination --- + // buildAlertsPage constructs the message text and inline keyboard for a paginated alerts list. func buildAlertsPage(alerts []entities.Alert, page int) (string, tgbotapi.InlineKeyboardMarkup) { total := len(alerts) @@ -516,16 +797,16 @@ func (b *Bot) handleAlertSelect(ctx context.Context, chatID int64, messageID int } } -func (b *Bot) cmdCancel(tgID entities.TelegramID, chatID int64) { - b.setState(tgID, &userState{}) - b.sendMenu(chatID, "Operation cancelled.") -} - // --- Multi-step flow handlers --- // handleInstrumentSelected fetches the current price, stores it in state, and prompts for a target price. func (b *Bot) handleInstrumentSelected(ctx context.Context, tgID entities.TelegramID, chatID int64, instrID entities.InstrumentID) { - instruments, err := b.usecase.InstrumentList(ctx, 0, 50) + user, err := b.requireUser(ctx, tgID, chatID) + if err != nil { + return + } + + instruments, err := b.usecase.InstrumentList(ctx, user.ID, 0, 200) if err != nil { b.log.Error("failed to list instruments", "err", err) b.sendMenu(chatID, "Failed to load instruments.") @@ -701,6 +982,46 @@ func (b *Bot) handleAlertTimeframe(ctx context.Context, tgID entities.TelegramID )) } +// handleAddPairInput processes the user-entered trading pair symbol. +func (b *Bot) handleAddPairInput(ctx context.Context, tgID entities.TelegramID, chatID int64, text string) { + user, err := b.requireUser(ctx, tgID, chatID) + if err != nil { + return + } + + parts := strings.SplitN(strings.TrimSpace(text), "/", 2) + if len(parts) != 2 || strings.TrimSpace(parts[0]) == "" || strings.TrimSpace(parts[1]) == "" { + b.send(chatID, "Invalid format. Please enter the pair as BASE/QUOTE (e.g. DOGE/USDT):") + return + } + + base := strings.ToUpper(strings.TrimSpace(parts[0])) + quote := strings.ToUpper(strings.TrimSpace(parts[1])) + + exists, err := b.provider.InstrumentExists(ctx, base, quote) + if err != nil { + b.log.Error("failed to check instrument existence", "base", base, "quote", quote, "err", err) + b.sendMenu(chatID, "Could not verify the trading pair. Please try again later.") + b.setState(tgID, &userState{}) + return + } + if !exists { + b.send(chatID, fmt.Sprintf("Pair %s/%s is not available on Bybit. Please enter a valid pair:", base, quote)) + return + } + + instr, err := b.usecase.AddUserInstrument(ctx, user.ID, base, quote) + if err != nil { + b.log.Error("failed to add user instrument", "err", err) + b.sendMenu(chatID, "Failed to add the pair. Please try again.") + b.setState(tgID, &userState{}) + return + } + + b.setState(tgID, &userState{}) + b.sendMenu(chatID, fmt.Sprintf("Pair %s/%s has been added to your list!", instr.BaseCurrency, instr.QuoteCurrency)) +} + // detectCondition returns above/below based on target vs current ask. func (b *Bot) detectCondition(state *userState) entities.AlertCondition { if state.currentPrice != nil && state.targetPrice.GreaterThanOrEqual(state.currentPrice.Ask) { @@ -818,6 +1139,7 @@ func menuKeyboard() tgbotapi.ReplyKeyboardMarkup { ), tgbotapi.NewKeyboardButtonRow( tgbotapi.NewKeyboardButton(btnInstruments), + tgbotapi.NewKeyboardButton(btnAddPair), ), ) } diff --git a/internal/entities/instrument.go b/internal/entities/instrument.go index eac2ee2..ba8892d 100644 --- a/internal/entities/instrument.go +++ b/internal/entities/instrument.go @@ -6,4 +6,5 @@ type Instrument struct { ID InstrumentID BaseCurrency string // base currency of the pair. e.g. BTC. QuoteCurrency string // quote currency of the pair. e.g. USDT. + IsGlobal bool // true for pre-seeded pairs visible to all users. } diff --git a/internal/provider/bybit/bybit.go b/internal/provider/bybit/bybit.go index 1486e87..44bc781 100644 --- a/internal/provider/bybit/bybit.go +++ b/internal/provider/bybit/bybit.go @@ -3,6 +3,7 @@ package bybit import ( "context" + "encoding/json" "fmt" "log/slog" "net/http" @@ -100,6 +101,28 @@ func intervalBybit(interval provider.KlineInterval) (string, error) { return i, nil } +// InstrumentExists reports whether base/quote is a valid spot pair on Bybit. +// A non-zero retCode (e.g. unknown symbol) is treated as "not found" — only +// actual transport / parse failures propagate as errors. +func (b *Bybit) InstrumentExists(ctx context.Context, base, quote string) (bool, error) { + req := marketOrderbookReq{ + Category: categorySpot, + Symbol: fmt.Sprintf("%s%s", base, quote), + } + + body, err := b.getRequest(ctx, "/v5/market/orderbook", req) + if err != nil { + return false, fmt.Errorf("failed to check instrument existence: %w", err) + } + + var resp response + if err := json.Unmarshal(body, &resp); err != nil { + return false, fmt.Errorf("failed to parse response: %w", err) + } + + return resp.RetCode == 0, nil +} + // Candles returns OHLC candles for the given interval in the [from, to) range. // It paginates automatically when the range exceeds klineLimit candles per request. func (b *Bybit) Candles(ctx context.Context, instrument entities.Instrument, from, to time.Time, interval provider.KlineInterval) ([]entities.Candle, error) { diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 89bdad1..a4273f6 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -17,6 +17,10 @@ type Provider interface { // The implementation handles pagination automatically when the range exceeds one // request's capacity. Candles(ctx context.Context, instrument entities.Instrument, from, to time.Time, interval KlineInterval) ([]entities.Candle, error) + + // InstrumentExists reports whether the trading pair base/quote is listed on this provider. + // Returns (false, nil) when the symbol is simply not found (as opposed to a network error). + InstrumentExists(ctx context.Context, base, quote string) (bool, error) } type KlineInterval string diff --git a/internal/repository/postgresql/instrument.go b/internal/repository/postgresql/instrument.go index ccd4a04..27877cd 100644 --- a/internal/repository/postgresql/instrument.go +++ b/internal/repository/postgresql/instrument.go @@ -8,15 +8,20 @@ import ( ) const instrumentListQuery = ` -select i.id, c_base.symbol, c_quote.symbol +select i.id, c_base.symbol, c_quote.symbol, i.is_global from instrument i join currency c_base on c_base.id = i.base_currency_id join currency c_quote on c_quote.id = i.quoted_currency_id -order by i.id desc -offset $1 limit $2` +where i.is_global = true + or exists ( + select 1 from user_instrument ui + where ui.instrument_id = i.id and ui.user_id = $1 + ) +order by i.is_global desc, i.id asc +offset $2 limit $3` -func (p *Postgresql) InstrumentList(ctx context.Context, offset, limit int) ([]entities.Instrument, error) { - rows, err := p.db.Query(ctx, instrumentListQuery, offset, limit) +func (p *Postgresql) InstrumentList(ctx context.Context, userID entities.UserID, offset, limit int) ([]entities.Instrument, error) { + rows, err := p.db.Query(ctx, instrumentListQuery, userID, offset, limit) if err != nil { return nil, fmt.Errorf("failed to exec instrumentListQuery: %w", err) } @@ -25,7 +30,7 @@ func (p *Postgresql) InstrumentList(ctx context.Context, offset, limit int) ([]e var instruments []entities.Instrument for rows.Next() { var inst entities.Instrument - if err := rows.Scan(&inst.ID, &inst.BaseCurrency, &inst.QuoteCurrency); err != nil { + if err := rows.Scan(&inst.ID, &inst.BaseCurrency, &inst.QuoteCurrency, &inst.IsGlobal); err != nil { return nil, fmt.Errorf("failed to scan instrument row: %w", err) } instruments = append(instruments, inst) @@ -34,21 +39,62 @@ func (p *Postgresql) InstrumentList(ctx context.Context, offset, limit int) ([]e return instruments, nil } +// createInstrumentQuery upserts both currency symbols then the instrument itself. +// It always returns an ID — the newly inserted row or the existing one on conflict. const createInstrumentQuery = ` -insert into instrument(base_currency_id, quoted_currency_id) -values ( - (select id from currency where symbol = $1), - (select id from currency where symbol = $2) +with upsert_base as ( + insert into currency(symbol) values($1) + on conflict(symbol) do update set symbol = excluded.symbol + returning id +), +upsert_quote as ( + insert into currency(symbol) values($2) + on conflict(symbol) do update set symbol = excluded.symbol + returning id +), +ins as ( + insert into instrument(base_currency_id, quoted_currency_id, is_global) + select upsert_base.id, upsert_quote.id, false + from upsert_base, upsert_quote + on conflict (base_currency_id, quoted_currency_id) do nothing + returning id ) -returning id` +select coalesce( + (select id from ins), + (select i.id from instrument i + where i.base_currency_id = (select id from upsert_base) + and i.quoted_currency_id = (select id from upsert_quote)) +)` func (p *Postgresql) CreateInstrument(ctx context.Context, instrument *entities.Instrument) (entities.InstrumentID, error) { var id entities.InstrumentID - err := p.db.QueryRow(ctx, createInstrumentQuery, instrument.BaseCurrency, instrument.QuoteCurrency).Scan(&id) if err != nil { return "", fmt.Errorf("failed to exec createInstrumentQuery: %w", err) } - return id, nil } + +const addUserInstrumentQuery = ` +insert into user_instrument(user_id, instrument_id) +values ($1, $2) +on conflict (user_id, instrument_id) do nothing` + +func (p *Postgresql) AddUserInstrument(ctx context.Context, userID entities.UserID, instrumentID entities.InstrumentID) error { + _, err := p.db.Exec(ctx, addUserInstrumentQuery, userID, instrumentID) + if err != nil { + return fmt.Errorf("failed to exec addUserInstrumentQuery: %w", err) + } + return nil +} + +const removeUserInstrumentQuery = ` +delete from user_instrument where user_id = $1 and instrument_id = $2` + +func (p *Postgresql) RemoveUserInstrument(ctx context.Context, userID entities.UserID, instrumentID entities.InstrumentID) error { + _, err := p.db.Exec(ctx, removeUserInstrumentQuery, userID, instrumentID) + if err != nil { + return fmt.Errorf("failed to exec removeUserInstrumentQuery: %w", err) + } + return nil +} diff --git a/internal/repository/postgresql/migrations/000004_user_instruments.down.sql b/internal/repository/postgresql/migrations/000004_user_instruments.down.sql new file mode 100644 index 0000000..55e93f2 --- /dev/null +++ b/internal/repository/postgresql/migrations/000004_user_instruments.down.sql @@ -0,0 +1,2 @@ +drop table if exists user_instrument; +alter table instrument drop column if exists is_global; diff --git a/internal/repository/postgresql/migrations/000004_user_instruments.up.sql b/internal/repository/postgresql/migrations/000004_user_instruments.up.sql new file mode 100644 index 0000000..9e53e5d --- /dev/null +++ b/internal/repository/postgresql/migrations/000004_user_instruments.up.sql @@ -0,0 +1,8 @@ +alter table instrument add column is_global bool not null default false; +update instrument set is_global = true; + +create table user_instrument ( + user_id uuid references users(id) not null, + instrument_id uuid references instrument(id) not null, + primary key (user_id, instrument_id) +); diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 98166ea..b346467 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -13,8 +13,16 @@ type Storage interface { UserByID(ctx context.Context, id entities.UserID) (*entities.User, error) UserByTelegramID(ctx context.Context, tgID entities.TelegramID) (*entities.User, error) - InstrumentList(ctx context.Context, offset, limit int) ([]entities.Instrument, error) + // InstrumentList returns instruments visible to userID: global ones plus any + // the user has explicitly added. + InstrumentList(ctx context.Context, userID entities.UserID, offset, limit int) ([]entities.Instrument, error) + // CreateInstrument upserts the instrument (and its currencies) and returns the ID, + // whether the row was just created or already existed. CreateInstrument(ctx context.Context, instrument *entities.Instrument) (entities.InstrumentID, error) + // AddUserInstrument links an instrument to a user (idempotent). + AddUserInstrument(ctx context.Context, userID entities.UserID, instrumentID entities.InstrumentID) error + // RemoveUserInstrument removes a user's link to a non-global instrument. + RemoveUserInstrument(ctx context.Context, userID entities.UserID, instrumentID entities.InstrumentID) error SaveAlert(ctx context.Context, alert *entities.Alert) (entities.AlertID, error) AllActiveAlerts(ctx context.Context) ([]entities.Alert, error) diff --git a/internal/usecase/instrument.go b/internal/usecase/instrument.go index 40b0242..8521eb4 100644 --- a/internal/usecase/instrument.go +++ b/internal/usecase/instrument.go @@ -3,12 +3,13 @@ package usecase import ( "context" "fmt" + "strings" "gitea.computernetthings.ru/yash/crypto_alert_bot/internal/entities" ) -func (uc *Usecase) InstrumentList(ctx context.Context, offset, limit int) ([]entities.Instrument, error) { - instruments, err := uc.storage.InstrumentList(ctx, offset, limit) +func (uc *Usecase) InstrumentList(ctx context.Context, userID entities.UserID, offset, limit int) ([]entities.Instrument, error) { + instruments, err := uc.storage.InstrumentList(ctx, userID, offset, limit) if err != nil { uc.log.Error("failed to list instruments", "offset", offset, "limit", limit, "err", err) return nil, fmt.Errorf("failed to list instruments: %w", err) @@ -26,3 +27,34 @@ func (uc *Usecase) CreateInstrument(ctx context.Context, instrument *entities.In return id, nil } + +func (uc *Usecase) RemoveUserInstrument(ctx context.Context, userID entities.UserID, instrumentID entities.InstrumentID) error { + if err := uc.storage.RemoveUserInstrument(ctx, userID, instrumentID); err != nil { + uc.log.Error("failed to remove user instrument", "user_id", userID, "instrument_id", instrumentID, "err", err) + return fmt.Errorf("failed to remove user instrument: %w", err) + } + return nil +} + +// AddUserInstrument ensures the instrument exists in the DB (creating it if needed) +// and links it to the given user. Returns the instrument with its ID filled in. +func (uc *Usecase) AddUserInstrument(ctx context.Context, userID entities.UserID, base, quote string) (*entities.Instrument, error) { + base = strings.ToUpper(strings.TrimSpace(base)) + quote = strings.ToUpper(strings.TrimSpace(quote)) + + instr := &entities.Instrument{BaseCurrency: base, QuoteCurrency: quote} + + id, err := uc.storage.CreateInstrument(ctx, instr) + if err != nil { + uc.log.Error("failed to upsert instrument", "base", base, "quote", quote, "err", err) + return nil, fmt.Errorf("failed to upsert instrument: %w", err) + } + instr.ID = id + + if err := uc.storage.AddUserInstrument(ctx, userID, id); err != nil { + uc.log.Error("failed to add user instrument", "user_id", userID, "instrument_id", id, "err", err) + return nil, fmt.Errorf("failed to add user instrument: %w", err) + } + + return instr, nil +}