// discord-tweeter/cmd/database.go

package cmd
import (
"errors"
"fmt"
"github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3"
ts "github.com/n0madic/twitter-scraper"
"strconv"
)
// SqliteSchema is the DDL run by NewDatabase for the sqlite3 driver. It is
// idempotent (CREATE TABLE IF NOT EXISTS) so it is safe to run on every start.
const SqliteSchema = `
CREATE TABLE IF NOT EXISTS tweet (
tweet_id INTEGER PRIMARY KEY AUTOINCREMENT,
snowflake SQLITE_UINT64_TYPE NOT NULL UNIQUE,
channel VARCHAR(15) NOT NULL,
timestamp SQLITE_INT64_TYPE NOT NULL
);
`

// KeepTweets is how many tweets per channel are kept in the database;
// PruneOldestTweets removes the oldest rows beyond this count.
const KeepTweets int = 10
// Tweet is one row of the tweet table.
//
// Keep the field order stable: Tweet values may be built with positional
// composite literals (TweetId, Snowflake, Channel, Timestamp).
type Tweet struct {
	// TweetId is the table's AUTOINCREMENT primary key; zero before insert.
	TweetId int `db:"tweet_id"`
	// Snowflake is the tweet's numeric ID, parsed from the scraper's string ID.
	Snowflake uint64 `db:"snowflake"`
	// Channel is the channel name the tweet was recorded under.
	Channel string `db:"channel"`
	// Timestamp is copied from the scraped tweet's Timestamp field.
	Timestamp int64 `db:"timestamp"`
}
// Database wraps an *sqlx.DB handle and adds the tweet bookkeeping methods
// defined in this file. The embedded handle exposes the full sqlx API
// (Get, Select, Exec, NamedExec, Beginx, ...) directly on Database.
type Database struct {
	*sqlx.DB // embedded so callers can use sqlx methods without unwrapping
}
// ErrUnsupportedDriver is wrapped into the error NewDatabase returns when the
// requested driver is not supported; test for it with errors.Is.
var ErrUnsupportedDriver = errors.New("database driver not supported")

// NewDatabase opens a database with the given driver and makes sure the tweet
// schema exists.
//
// Only "sqlite3" is supported: connectString is opened as
// "file:<connectString>?cache=shared", the pool is capped at one open
// connection, and SqliteSchema is executed. Any other driver returns an error
// wrapping ErrUnsupportedDriver. If schema creation fails, the freshly opened
// connection is closed before the error is returned.
func NewDatabase(driver string, connectString string) (*Database, error) {
	switch driver {
	case "sqlite3":
		conn, err := sqlx.Connect(driver, "file:"+connectString+"?cache=shared")
		if err != nil {
			return nil, err
		}
		// Route every query through a single SQLite connection.
		conn.SetMaxOpenConns(1)
		if _, err := conn.Exec(SqliteSchema); err != nil {
			// Don't leak the handle; the Exec error is the one worth reporting.
			_ = conn.Close()
			return nil, err
		}
		return &Database{conn}, nil
	default:
		return nil, fmt.Errorf("%w: %q", ErrUnsupportedDriver, driver)
	}
}
// GetNewestTweet returns the most recent tweet stored for channel, ordered by
// timestamp and then snowflake (both descending). If the channel has no rows,
// the error from sqlx's Get is returned unchanged (sqlx reports an empty result
// as sql.ErrNoRows).
func (db *Database) GetNewestTweet(channel string) (*Tweet, error) {
	var newest Tweet
	if err := db.Get(&newest, "SELECT * FROM tweet WHERE channel=$1 ORDER BY timestamp DESC, snowflake DESC LIMIT 1", channel); err != nil {
		return nil, err
	}
	return &newest, nil
}
// GetTweets returns every tweet stored for channel, newest first (timestamp,
// then snowflake, descending). The slice may be empty; on a query error it
// returns nil and the error.
func (db *Database) GetTweets(channel string) ([]*Tweet, error) {
	const query = "SELECT * FROM tweet WHERE channel=$1 ORDER BY timestamp DESC, snowflake DESC"

	tweets := make([]*Tweet, 0)
	if err := db.Select(&tweets, query, channel); err != nil {
		return nil, err
	}
	return tweets, nil
}
// ContainsTweet reports whether tweet has already been stored for channel,
// matching on its snowflake ID.
//
// It returns an error if tweet.ID is not a base-10 uint64 or the query fails.
func (db *Database) ContainsTweet(channel string, tweet *ts.Tweet) (bool, error) {
	snowflake, err := strconv.ParseUint(tweet.ID, 10, 64)
	if err != nil {
		return false, err
	}
	// A single EXISTS lookup returns exactly one row, so no *Rows cursor is
	// left open. That matters because NewDatabase caps the pool at one
	// connection: an unclosed cursor would block every later query. It also
	// avoids streaming the whole channel back just to compare IDs in Go.
	var found bool
	err = db.Get(&found, "SELECT EXISTS(SELECT 1 FROM tweet WHERE channel=$1 AND snowflake=$2)", channel, snowflake)
	if err != nil {
		return false, err
	}
	return found, nil
}
// InsertTweet stores tweet as a row for channel.
//
// It returns an error if tweet.ID is not a base-10 uint64, or the error from
// the INSERT. For example, the INSERT fails when the snowflake is already
// present, because the column is UNIQUE.
func (db *Database) InsertTweet(channel string, tweet *ts.Tweet) error {
	row, err := FromTweet(channel, tweet)
	if err != nil {
		return err
	}
	// tweet_id is omitted so SQLite assigns it (AUTOINCREMENT).
	_, err = db.NamedExec("INSERT INTO tweet (snowflake, channel, timestamp) VALUES (:snowflake, :channel, :timestamp)", row)
	return err
}
// PruneOldestTweets deletes the oldest tweets stored for channel so that at
// most KeepTweets rows remain. Rows are ranked oldest-first by timestamp, then
// snowflake. It is a no-op when the channel is already within the limit.
func (db *Database) PruneOldestTweets(channel string) error {
	var count int
	if err := db.Get(&count, "SELECT COUNT(*) FROM tweet WHERE channel=$1", channel); err != nil {
		return err
	}
	excess := count - KeepTweets
	if excess <= 0 {
		return nil
	}
	// One DELETE with a ranked subquery is a single atomic statement. It needs
	// no explicit transaction and keeps no cursor open while rows are removed.
	_, err := db.Exec(`DELETE FROM tweet WHERE tweet_id IN (
SELECT tweet_id FROM tweet WHERE channel=$1
ORDER BY timestamp ASC, snowflake ASC LIMIT $2)`, channel, excess)
	return err
}
// FromTweet builds a Tweet row for channel from a scraped tweet. TweetId is
// left at zero so the database assigns it on insert. If tweet.ID is not a
// base-10 uint64, the strconv.ParseUint error is returned.
func FromTweet(channel string, tweet *ts.Tweet) (*Tweet, error) {
	id, parseErr := strconv.ParseUint(tweet.ID, 10, 64)
	if parseErr != nil {
		return nil, parseErr
	}
	row := &Tweet{
		TweetId:   0,
		Snowflake: id,
		Channel:   channel,
		Timestamp: tweet.Timestamp,
	}
	return row, nil
}
// EqualsTweet reports whether t and the scraped tweet have the same snowflake
// ID. A tweet whose ID does not parse as a base-10 uint64 never matches.
func (t *Tweet) EqualsTweet(tweet *ts.Tweet) bool {
	other, err := strconv.ParseUint(tweet.ID, 10, 64)
	return err == nil && other == t.Snowflake
}
// Equals reports whether t and tweet refer to the same tweet. Only the
// snowflake ID is compared; TweetId, Channel and Timestamp are ignored.
func (t *Tweet) Equals(tweet *Tweet) bool {
	mine, theirs := t.Snowflake, tweet.Snowflake
	return mine == theirs
}