package cmd import ( "errors" "fmt" "github.com/jmoiron/sqlx" _ "github.com/mattn/go-sqlite3" ts "github.com/n0madic/twitter-scraper" "strconv" ) const ( SqliteSchema = ` CREATE TABLE IF NOT EXISTS tweet ( tweet_id INTEGER PRIMARY KEY AUTOINCREMENT, snowflake SQLITE_UINT64_TYPE NOT NULL UNIQUE, channel VARCHAR(15) NOT NULL, timestamp SQLITE_INT64_TYPE NOT NULL ); ` KeepTweets int = 10 // How many tweets to keep in database before pruning ) type Tweet struct { TweetId int `db:"tweet_id"` Snowflake uint64 `db:"snowflake"` Channel string `db:"channel"` Timestamp int64 `db:"timestamp"` } type Database struct { *sqlx.DB } func NewDatabase(driver string, connectString string) (*Database, error) { var connection *sqlx.DB var err error switch driver { case "sqlite3": connection, err = sqlx.Connect(driver, "file:"+connectString+"?cache=shared") if err != nil { return nil, err } connection.SetMaxOpenConns(1) if _, err = connection.Exec(SqliteSchema); err != nil { return nil, err } default: return nil, errors.New(fmt.Sprintf("Database driver %s not supported right now!", driver)) } return &Database{connection}, err } func (db *Database) GetNewestTweet(channel string) (*Tweet, error) { tweet := Tweet{} err := db.Get(&tweet, "SELECT * FROM tweet WHERE channel=$1 ORDER BY timestamp DESC LIMIT 1", channel) if err != nil { return nil, err } return &tweet, nil } func (db *Database) GetTweets(channel string) ([]*Tweet, error) { tweet := []*Tweet{} err := db.Select(&tweet, "SELECT * FROM tweet WHERE channel=$1 ORDER BY timestamp DESC", channel) if err != nil { return nil, err } return tweet, nil } func (db *Database) ContainsTweet(channel string, tweet *ts.Tweet) (bool, error) { snowflake, err := strconv.ParseUint(tweet.ID, 10, 64) if err != nil { return false, err } t := Tweet{} rows, err := db.Queryx("SELECT * FROM tweet WHERE channel=$1 ORDER BY timestamp DESC", channel) if err != nil { return false, err } for rows.Next() { err := rows.StructScan(&t) if err != nil { return false, err } if t.Snowflake == snowflake { return true, nil } } return false, nil } func (db *Database) InsertTweet(channel string, tweet *ts.Tweet) error { snowflake, err := strconv.ParseUint(tweet.ID, 10, 64) if err != nil { return err } _, dberr := db.NamedExec("INSERT INTO tweet (snowflake, channel, timestamp) VALUES (:snowflake, :channel, :timestamp)", &Tweet{0, snowflake, channel, tweet.Timestamp}) if dberr != nil { return err } return nil } func (db *Database) PruneOldestTweets(channel string) error { var count int err := db.Get(&count, "SELECT COUNT(*) FROM tweet WHERE channel=$1", channel) if err != nil { return err } if count > KeepTweets { tx, err := db.Beginx() if err != nil { tx.Rollback() return err } rows, err := tx.Queryx("SELECT tweet_id from tweet WHERE channel=$1 ORDER by timestamp ASC LIMIT $2", channel, count-KeepTweets) if err != nil { tx.Rollback() return err } for rows.Next() { var i int err = rows.Scan(&i) if err != nil { tx.Rollback() return err } _, err = tx.Exec("DELETE FROM tweet WHERE tweet_id=$1", i) if err != nil { tx.Rollback() return err } } tx.Commit() } return nil } func FromTweet(channel string, tweet *ts.Tweet) (*Tweet, error) { snowflake, err := strconv.ParseUint(tweet.ID, 10, 64) if err != nil { return nil, err } return &Tweet{0, snowflake, channel, tweet.Timestamp}, nil } func (t *Tweet) EqualsTweet(tweet *ts.Tweet) bool { snowflake, err := strconv.ParseUint(tweet.ID, 10, 64) if err != nil { return false } return t.Snowflake == snowflake } func (t *Tweet) Equals(tweet *Tweet) bool { return t.Snowflake == tweet.Snowflake }