From 608561ab3870e3688352c093c25e7ccfbd22731c Mon Sep 17 00:00:00 2001 From: yash Date: Wed, 25 Feb 2026 23:52:21 +0300 Subject: [PATCH] alert service --- internal/entities/alert.go | 8 ++ internal/repository/postgresql/alert.go | 64 ++++++++- .../migrations/000001_init.down.sql | 1 + .../postgresql/migrations/000001_init.up.sql | 5 +- internal/repository/repository.go | 2 + internal/service/alerter/alerter.go | 131 ++++++++++++++++++ internal/service/alerter/cache.go | 71 ++++++++++ internal/usecase/alert.go | 9 ++ 8 files changed, 283 insertions(+), 8 deletions(-) create mode 100644 internal/service/alerter/alerter.go create mode 100644 internal/service/alerter/cache.go diff --git a/internal/entities/alert.go b/internal/entities/alert.go index 0d62e53..9838cb1 100644 --- a/internal/entities/alert.go +++ b/internal/entities/alert.go @@ -4,9 +4,17 @@ import "github.com/shopspring/decimal" type AlertID string +type AlertCondition string + +const ( + AlertConditionAbove AlertCondition = "above" // trigger when price rises to target + AlertConditionBelow AlertCondition = "below" // trigger when price drops to target +) + type Alert struct { ID AlertID UserID UserID Price decimal.Decimal + Condition AlertCondition Instrument Instrument } diff --git a/internal/repository/postgresql/alert.go b/internal/repository/postgresql/alert.go index c6c7baa..0dd9c45 100644 --- a/internal/repository/postgresql/alert.go +++ b/internal/repository/postgresql/alert.go @@ -9,14 +9,14 @@ import ( ) const saveAlertQuery = ` -insert into alert(user_id, instrument_id, price) -values ($1, $2, $3) +insert into alert(user_id, instrument_id, price, condition) +values ($1, $2, $3, $4) returning id` func (p *Postgresql) SaveAlert(ctx context.Context, alert *entities.Alert) (entities.AlertID, error) { var id entities.AlertID - err := p.db.QueryRow(ctx, saveAlertQuery, alert.UserID, alert.Instrument.ID, alert.Price).Scan(&id) + err := p.db.QueryRow(ctx, saveAlertQuery, alert.UserID, alert.Instrument.ID, alert.Price.String(), alert.Condition).Scan(&id) if err != nil { return "", fmt.Errorf("failed to exec saveAlertQuery: %w", err) } @@ -24,8 +24,47 @@ func (p *Postgresql) SaveAlert(ctx context.Context, alert *entities.Alert) (enti return id, nil } +const allActiveAlertsQuery = ` +select a.id, a.user_id, a.price, a.condition, i.id, c_base.symbol, c_quote.symbol +from alert a +join instrument i on i.id = a.instrument_id +join currency c_base on c_base.id = i.base_currency_id +join currency c_quote on c_quote.id = i.quoted_currency_id +where a.active = true +order by a.id` + +func (p *Postgresql) AllActiveAlerts(ctx context.Context) ([]entities.Alert, error) { + rows, err := p.db.Query(ctx, allActiveAlertsQuery) + if err != nil { + return nil, fmt.Errorf("failed to exec allActiveAlertsQuery: %w", err) + } + defer rows.Close() + + var alerts []entities.Alert + for rows.Next() { + var alert entities.Alert + var priceStr string + + if err := rows.Scan( + &alert.ID, &alert.UserID, &priceStr, &alert.Condition, + &alert.Instrument.ID, &alert.Instrument.BaseCurrency, &alert.Instrument.QuoteCurrency, + ); err != nil { + return nil, fmt.Errorf("failed to scan alert row: %w", err) + } + + alert.Price, err = decimal.NewFromString(priceStr) + if err != nil { + return nil, fmt.Errorf("failed to parse alert price: %w", err) + } + + alerts = append(alerts, alert) + } + + return alerts, nil +} + const alertByIDQuery = ` -select a.id, a.user_id, a.price, i.id, c_base.symbol, c_quote.symbol +select a.id, a.user_id, a.price, a.condition, i.id, c_base.symbol, c_quote.symbol from alert a join instrument i on i.id = a.instrument_id join currency c_base on c_base.id = i.base_currency_id @@ -37,7 +76,7 @@ func (p *Postgresql) AlertByID(ctx context.Context, id entities.AlertID) (*entit var priceStr string err := p.db.QueryRow(ctx, alertByIDQuery, id).Scan( - &alert.ID, &alert.UserID, &priceStr, + &alert.ID, &alert.UserID, &priceStr, &alert.Condition, &alert.Instrument.ID, &alert.Instrument.BaseCurrency, &alert.Instrument.QuoteCurrency, ) if err != nil { @@ -53,7 +92,7 @@ func (p *Postgresql) AlertByID(ctx context.Context, id entities.AlertID) (*entit } const alertsByUserIDQuery = ` -select a.id, a.user_id, a.price, i.id, c_base.symbol, c_quote.symbol +select a.id, a.user_id, a.price, a.condition, i.id, c_base.symbol, c_quote.symbol from alert a join instrument i on i.id = a.instrument_id join currency c_base on c_base.id = i.base_currency_id @@ -75,7 +114,7 @@ func (p *Postgresql) AlertsByUserID(ctx context.Context, userID entities.UserID, var priceStr string if err := rows.Scan( - &alert.ID, &alert.UserID, &priceStr, + &alert.ID, &alert.UserID, &priceStr, &alert.Condition, &alert.Instrument.ID, &alert.Instrument.BaseCurrency, &alert.Instrument.QuoteCurrency, ); err != nil { return nil, fmt.Errorf("failed to scan alert row: %w", err) @@ -103,6 +142,17 @@ func (p *Postgresql) DeleteAlert(ctx context.Context, id entities.AlertID) error return nil } +const disableAlertQuery = "update alert set active = false where id = $1" + +func (p *Postgresql) DisableAlert(ctx context.Context, id entities.AlertID) error { + _, err := p.db.Exec(ctx, disableAlertQuery, id) + if err != nil { + return fmt.Errorf("failed to exec disableAlertQuery: %w", err) + } + + return nil +} + const updateAlertPriceQuery = "update alert set price = $2 where id = $1" func (p *Postgresql) UpdateAlertPrice(ctx context.Context, id entities.AlertID, price decimal.Decimal) error { diff --git a/internal/repository/postgresql/migrations/000001_init.down.sql b/internal/repository/postgresql/migrations/000001_init.down.sql index a0d959b..fbb669b 100644 --- a/internal/repository/postgresql/migrations/000001_init.down.sql +++ b/internal/repository/postgresql/migrations/000001_init.down.sql @@ -1,4 +1,5 @@ drop table if exists alert; +drop type alert_condition; drop table if exists instrument; drop table if exists currency; drop table if exists users; diff --git a/internal/repository/postgresql/migrations/000001_init.up.sql b/internal/repository/postgresql/migrations/000001_init.up.sql index 6bba2f1..35ab8d8 100644 --- a/internal/repository/postgresql/migrations/000001_init.up.sql +++ b/internal/repository/postgresql/migrations/000001_init.up.sql @@ -17,12 +17,15 @@ create table if not exists instrument ( UNIQUE (base_currency_id, quoted_currency_id) ); +create type alert_condition as enum ('above', 'below'); + create table if not exists alert ( id uuid primary key not null default gen_random_uuid(), user_id uuid references users(id) not null, instrument_id uuid references instrument(id) not null, price text not null, - active bool not null default true + active bool not null default true, + condition alert_condition not null ); insert into currency(symbol) values ('USDT'), ('BTC'), ('ETH'), ('SOL'); diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 0bc0ac5..40fe972 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -16,8 +16,10 @@ type Storage interface { CreateInstrument(ctx context.Context, instrument *entities.Instrument) (entities.InstrumentID, error) SaveAlert(ctx context.Context, alert *entities.Alert) (entities.AlertID, error) + AllActiveAlerts(ctx context.Context) ([]entities.Alert, error) AlertByID(ctx context.Context, id entities.AlertID) (*entities.Alert, error) AlertsByUserID(ctx context.Context, userID entities.UserID, offset, limit int) ([]entities.Alert, error) DeleteAlert(ctx context.Context, id entities.AlertID) error + DisableAlert(ctx context.Context, id entities.AlertID) error UpdateAlertPrice(ctx context.Context, id entities.AlertID, price decimal.Decimal) error } diff --git a/internal/service/alerter/alerter.go b/internal/service/alerter/alerter.go new file mode 100644 index 0000000..3a96ac2 --- /dev/null +++ b/internal/service/alerter/alerter.go @@ -0,0 +1,131 @@ +package alerter + +import ( + "context" + "fmt" + "log/slog" + "time" + + "gitea.computernetthings.ru/yash/crypto_alert_bot/internal/entities" + "gitea.computernetthings.ru/yash/crypto_alert_bot/internal/provider" + "github.com/shopspring/decimal" +) + + +type Notifier interface { + NotifyAlert(ctx context.Context, userID entities.UserID, alert *entities.Alert, currentPrice decimal.Decimal) error +} + +type Storage interface { + AllActiveAlerts(ctx context.Context) ([]entities.Alert, error) + DisableAlert(ctx context.Context, id entities.AlertID) error +} + +type Alerter struct { + log *slog.Logger + cache *alertsCache + priceProvider provider.Provider + notifier Notifier + storage Storage +} + +const interval = time.Minute + +func New(log *slog.Logger, priceProvider provider.Provider, notifier Notifier, storage Storage) *Alerter { + return &Alerter{ + log: log, + cache: newCache(), + priceProvider: priceProvider, + notifier: notifier, + storage: storage, + } +} + +func (a *Alerter) LoadAlerts(ctx context.Context) error { + alerts, err := a.storage.AllActiveAlerts(ctx) + if err != nil { + return fmt.Errorf("failed to load alerts: %w", err) + } + + for i := range alerts { + a.cache.Add(&alerts[i]) + } + + a.log.Info("alerts loaded", "count", len(alerts)) + return nil +} + +func (a *Alerter) AddAlert(alert *entities.Alert) { + a.cache.Add(alert) +} + +func (a *Alerter) RemoveAlert(id entities.AlertID) { + a.cache.Remove(id) +} + +func (a *Alerter) Run(ctx context.Context) { + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + + a.log.Info("start checking alerts") + + if err := a.checkAlerts(ctx); err != nil { + a.log.Error("failed to check alerts", "err", err) + continue + } + + a.log.Info("alerts checked") + } + }() +} + +// TODO: parallel checking for different instruments. + +func (a *Alerter) checkAlerts(ctx context.Context) error { + instruments := a.cache.Instruments() + + for _, instrument := range instruments { + price, err := a.priceProvider.Price(ctx, instrument) + if err != nil { + a.log.Error("failed to get price", "instrument", instrument.ID, "err", err) + continue + } + + alerts := a.cache.AlertsByInstrument(instrument.ID) + for _, alert := range alerts { + switch alert.Condition { + case entities.AlertConditionAbove: + if price.Ask.GreaterThanOrEqual(alert.Price) { + a.triggerAlert(ctx, alert, price.Ask) + } + case entities.AlertConditionBelow: + if price.Bid.LessThanOrEqual(alert.Price) { + a.triggerAlert(ctx, alert, price.Bid) + } + } + } + } + + return nil +} + +func (a *Alerter) triggerAlert(ctx context.Context, alert *entities.Alert, currentPrice decimal.Decimal) { + if err := a.notifier.NotifyAlert(ctx, alert.UserID, alert, currentPrice); err != nil { + a.log.Error("failed to notify alert", "alert_id", alert.ID, "err", err) + return + } + + a.cache.Remove(alert.ID) + + if err := a.storage.DisableAlert(ctx, alert.ID); err != nil { + a.log.Error("failed to disable alert in db", "alert_id", alert.ID, "err", err) + } +} diff --git a/internal/service/alerter/cache.go b/internal/service/alerter/cache.go new file mode 100644 index 0000000..d7de5c2 --- /dev/null +++ b/internal/service/alerter/cache.go @@ -0,0 +1,71 @@ +package alerter + +import ( + "sync" + + "gitea.computernetthings.ru/yash/crypto_alert_bot/internal/entities" +) + +type alertsCache struct { + mu sync.RWMutex + byID map[entities.AlertID]*entities.Alert + byInstrument map[entities.InstrumentID]map[entities.AlertID]*entities.Alert +} + +func newCache() *alertsCache { + return &alertsCache{ + byID: make(map[entities.AlertID]*entities.Alert), + byInstrument: make(map[entities.InstrumentID]map[entities.AlertID]*entities.Alert), + } +} + +func (c *alertsCache) Add(a *entities.Alert) { + c.mu.Lock() + defer c.mu.Unlock() + + c.byID[a.ID] = a + + if _, ok := c.byInstrument[a.Instrument.ID]; !ok { + c.byInstrument[a.Instrument.ID] = make(map[entities.AlertID]*entities.Alert) + } + c.byInstrument[a.Instrument.ID][a.ID] = a +} + +func (c *alertsCache) Remove(id entities.AlertID) { + c.mu.Lock() + defer c.mu.Unlock() + + a, ok := c.byID[id] + if !ok { + return + } + + delete(c.byID, id) + delete(c.byInstrument[a.Instrument.ID], id) +} + +func (c *alertsCache) AlertsByInstrument(id entities.InstrumentID) []*entities.Alert { + c.mu.RLock() + defer c.mu.RUnlock() + + alerts := c.byInstrument[id] + result := make([]*entities.Alert, 0, len(alerts)) + for _, a := range alerts { + result = append(result, a) + } + return result +} + +func (c *alertsCache) Instruments() []entities.Instrument { + c.mu.RLock() + defer c.mu.RUnlock() + + instruments := make([]entities.Instrument, 0, len(c.byInstrument)) + for _, alerts := range c.byInstrument { + for _, a := range alerts { + instruments = append(instruments, a.Instrument) + break + } + } + return instruments +} diff --git a/internal/usecase/alert.go b/internal/usecase/alert.go index f7cf53b..83990ad 100644 --- a/internal/usecase/alert.go +++ b/internal/usecase/alert.go @@ -55,3 +55,12 @@ func (uc *Usecase) UpdateAlertPrice(ctx context.Context, alertID entities.AlertI return nil } + +func (uc *Usecase) DisableAlert(ctx context.Context, alertID entities.AlertID) error { + if err := uc.storage.DisableAlert(ctx, alertID); err != nil { + uc.log.Error("failed to disable alert", "alert_id", alertID, "err", err) + return fmt.Errorf("failed to disable alert: %w", err) + } + + return nil +}