171 lines
3.8 KiB
Go
171 lines
3.8 KiB
Go
package cmd
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"github.com/jmoiron/sqlx"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
ts "github.com/imperatrona/twitter-scraper"
|
|
"strconv"
|
|
)
|
|
|
|
// SqliteSchema is the DDL executed by NewDatabase when opening a sqlite3
// database. It is safe to run repeatedly (CREATE TABLE IF NOT EXISTS).
//
// tweet_id is the auto-incremented row key; snowflake is the tweet's own ID
// and is UNIQUE across all channels. The custom type names SQLITE_UINT64_TYPE
// and SQLITE_INT64_TYPE contain "INT", so SQLite gives those columns INTEGER
// affinity.
const SqliteSchema = `
CREATE TABLE IF NOT EXISTS tweet (
	tweet_id INTEGER PRIMARY KEY AUTOINCREMENT,
	snowflake SQLITE_UINT64_TYPE NOT NULL UNIQUE,
	channel VARCHAR(15) NOT NULL,
	timestamp SQLITE_INT64_TYPE NOT NULL
);
`

// KeepTweets is how many tweets are retained per channel; PruneOldestTweets
// deletes the oldest rows beyond this count.
const KeepTweets int = 10
|
|
|
|
// Tweet is one row of the tweet table. The db tags map each field to its
// column for sqlx scanning and named parameters.
//
// Field order matters: other code builds Tweet with positional literals
// ({TweetId, Snowflake, Channel, Timestamp}).
type Tweet struct {
	// TweetId is the table's auto-incremented primary key (zero before insert).
	TweetId int `db:"tweet_id"`
	// Snowflake is the tweet's own numeric ID; UNIQUE in the schema.
	Snowflake uint64 `db:"snowflake"`
	// Channel is the name the tweet is stored under.
	Channel string `db:"channel"`
	// Timestamp is the tweet's timestamp, used (with Snowflake) for ordering.
	Timestamp int64 `db:"timestamp"`
}
|
|
|
|
type Database struct {
|
|
*sqlx.DB
|
|
}
|
|
|
|
func NewDatabase(driver string, connectString string) (*Database, error) {
|
|
var connection *sqlx.DB
|
|
var err error
|
|
|
|
switch driver {
|
|
case "sqlite3":
|
|
connection, err = sqlx.Connect(driver, "file:"+connectString+"?cache=shared")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
connection.SetMaxOpenConns(1)
|
|
|
|
if _, err = connection.Exec(SqliteSchema); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
default:
|
|
return nil, errors.New(fmt.Sprintf("Database driver %s not supported right now!", driver))
|
|
}
|
|
|
|
return &Database{connection}, err
|
|
}
|
|
|
|
func (db *Database) GetNewestTweet(channel string) (*Tweet, error) {
|
|
tweet := Tweet{}
|
|
err := db.Get(&tweet, "SELECT * FROM tweet WHERE channel=$1 ORDER BY timestamp DESC, snowflake DESC LIMIT 1", channel)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &tweet, nil
|
|
}
|
|
|
|
func (db *Database) GetTweets(channel string) ([]*Tweet, error) {
|
|
tweet := []*Tweet{}
|
|
err := db.Select(&tweet, "SELECT * FROM tweet WHERE channel=$1 ORDER BY timestamp DESC, snowflake DESC", channel)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return tweet, nil
|
|
}
|
|
|
|
func (db *Database) ContainsTweet(channel string, tweet *ts.Tweet) (bool, error) {
|
|
snowflake, err := strconv.ParseUint(tweet.ID, 10, 64)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
t := Tweet{}
|
|
rows, err := db.Queryx("SELECT * FROM tweet WHERE channel=$1 ORDER BY timestamp DESC, snowflake DESC", channel)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
for rows.Next() {
|
|
err := rows.StructScan(&t)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if t.Snowflake == snowflake {
|
|
return true, nil
|
|
}
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
func (db *Database) InsertTweet(channel string, tweet *ts.Tweet) error {
|
|
snowflake, err := strconv.ParseUint(tweet.ID, 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, dberr := db.NamedExec("INSERT INTO tweet (snowflake, channel, timestamp) VALUES (:snowflake, :channel, :timestamp)", &Tweet{0, snowflake, channel, tweet.Timestamp})
|
|
if dberr != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (db *Database) PruneOldestTweets(channel string) error {
|
|
var count int
|
|
err := db.Get(&count, "SELECT COUNT(*) FROM tweet WHERE channel=$1", channel)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if count > KeepTweets {
|
|
tx, err := db.Beginx()
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
rows, err := tx.Queryx("SELECT tweet_id from tweet WHERE channel=$1 ORDER by timestamp ASC, snowflake ASC LIMIT $2", channel, count-KeepTweets)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
for rows.Next() {
|
|
var i int
|
|
err = rows.Scan(&i)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
_, err = tx.Exec("DELETE FROM tweet WHERE tweet_id=$1", i)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
}
|
|
tx.Commit()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func FromTweet(channel string, tweet *ts.Tweet) (*Tweet, error) {
|
|
snowflake, err := strconv.ParseUint(tweet.ID, 10, 64)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Tweet{0, snowflake, channel, tweet.Timestamp}, nil
|
|
}
|
|
|
|
func (t *Tweet) EqualsTweet(tweet *ts.Tweet) bool {
|
|
snowflake, err := strconv.ParseUint(tweet.ID, 10, 64)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
return t.Snowflake == snowflake
|
|
}
|
|
|
|
func (t *Tweet) Equals(tweet *Tweet) bool {
|
|
return t.Snowflake == tweet.Snowflake
|
|
}
|