package db

import (
	"fmt"
	"strconv"

	ts "github.com/imperatrona/twitter-scraper"
	"github.com/jmoiron/sqlx"
	_ "github.com/mattn/go-sqlite3"
)

// SqliteSchema is the DDL that New executes after opening a sqlite3
// database. CREATE TABLE IF NOT EXISTS makes it safe to run on every call,
// including against a file that already holds the tweet table.
const SqliteSchema = `
CREATE TABLE IF NOT EXISTS tweet (
	tweet_id   INTEGER PRIMARY KEY AUTOINCREMENT,
    snowflake  SQLITE_UINT64_TYPE NOT NULL UNIQUE,
    channel    VARCHAR(15) NOT NULL,
    timestamp  SQLITE_INT64_TYPE NOT NULL
);
`

// KeepTweets is how many tweets per channel are kept in the database;
// PruneOldestTweets removes the oldest ones beyond this count.
const KeepTweets int = 10

// Tweet is one row of the tweet table. The db struct tags map each field to
// its column so sqlx can scan rows into it and bind it in named queries.
type Tweet struct {
	TweetId   int    `db:"tweet_id"`  // primary key; AUTOINCREMENT in the schema
	Snowflake uint64 `db:"snowflake"` // tweet ID string parsed as an unsigned 64-bit integer; UNIQUE in the schema
	Channel   string `db:"channel"`   // channel the tweet is stored under
	Timestamp int64  `db:"timestamp"` // copied from the scraper's Tweet.Timestamp (presumably Unix seconds — confirm)
}

// Database is the tweet store. It embeds *sqlx.DB, so every sqlx and
// database/sql method is available directly on it. Construct it with New.
type Database struct {
	*sqlx.DB
}

// New opens a database with the given driver and makes sure the tweet table
// exists.
//
// Only the "sqlite3" driver is supported. connectString is embedded in a
// "file:" URI opened with cache=shared, the connection pool is capped at one
// open connection, and SqliteSchema is executed. Any other driver returns an
// error.
//
// On failure the zero Database is returned, and any connection that was
// already opened is closed rather than leaked.
func New(driver string, connectString string) (Database, error) {
	switch driver {
	case "sqlite3":
		connection, err := sqlx.Connect(driver, "file:"+connectString+"?cache=shared")
		if err != nil {
			return Database{}, fmt.Errorf("connecting to sqlite3 database %q: %w", connectString, err)
		}
		// A single pooled connection serializes all access to the database
		// (presumably to avoid SQLite lock contention).
		connection.SetMaxOpenConns(1)

		if _, err := connection.Exec(SqliteSchema); err != nil {
			// Best effort: the schema error is the one worth reporting, but
			// the connection must not be leaked.
			_ = connection.Close()
			return Database{}, fmt.Errorf("applying sqlite3 schema: %w", err)
		}
		return Database{connection}, nil

	default:
		return Database{}, fmt.Errorf("database driver %s not supported right now", driver)
	}
}

// GetNewestTweet returns the most recent tweet stored for channel. Recency
// is decided by timestamp, with the snowflake breaking ties. When the channel
// has no tweets, the error from sqlx's Get is returned unchanged (sqlx
// reports this as sql.ErrNoRows).
func (db Database) GetNewestTweet(channel string) (*Tweet, error) {
	const newestQuery = "SELECT * FROM tweet WHERE channel=$1 ORDER BY timestamp DESC, snowflake DESC LIMIT 1"

	var newest Tweet
	if err := db.Get(&newest, newestQuery, channel); err != nil {
		return nil, err
	}
	return &newest, nil
}

// GetTweets returns every tweet stored for channel, newest first (ordered
// by timestamp, then snowflake). The destination starts as an empty non-nil
// slice, so a channel with no tweets should yield [] rather than nil.
func (db Database) GetTweets(channel string) ([]*Tweet, error) {
	const allQuery = "SELECT * FROM tweet WHERE channel=$1 ORDER BY timestamp DESC, snowflake DESC"

	tweets := make([]*Tweet, 0)
	if err := db.Select(&tweets, allQuery, channel); err != nil {
		return nil, err
	}
	return tweets, nil
}

// ContainsTweet reports whether tweet is already stored for channel.
//
// The tweet ID is parsed as an unsigned 64-bit snowflake; if it does not
// parse, the strconv error is returned. The match happens in SQL through a
// single-row EXISTS query, so no result set is left open. That matters
// because, when the Database comes from New, the pool holds only one
// connection, and an unclosed *Rows would block every later query.
func (db Database) ContainsTweet(channel string, tweet *ts.Tweet) (bool, error) {
	snowflake, err := strconv.ParseUint(tweet.ID, 10, 64)
	if err != nil {
		return false, err
	}

	// EXISTS yields 0 or 1 in SQLite; scanning into an int avoids relying
	// on driver-specific bool conversion.
	var found int
	err = db.Get(&found, "SELECT EXISTS(SELECT 1 FROM tweet WHERE channel=$1 AND snowflake=$2)", channel, snowflake)
	if err != nil {
		return false, err
	}
	return found != 0, nil
}

// InsertTweet stores tweet under channel.
//
// The tweet ID is parsed into its numeric snowflake, and tweet_id is left
// for the database to assign. The function returns the parse error, or the
// database error wrapped with context, for example when the snowflake
// violates the UNIQUE constraint on an existing row.
func (db Database) InsertTweet(channel string, tweet *ts.Tweet) error {
	snowflake, err := strconv.ParseUint(tweet.ID, 10, 64)
	if err != nil {
		return err
	}

	row := &Tweet{Snowflake: snowflake, Channel: channel, Timestamp: tweet.Timestamp}
	if _, err := db.NamedExec("INSERT INTO tweet (snowflake, channel, timestamp) VALUES (:snowflake, :channel, :timestamp)", row); err != nil {
		return fmt.Errorf("inserting tweet %d for channel %q: %w", snowflake, channel, err)
	}
	return nil
}

// PruneOldestTweets deletes the oldest tweets stored for channel so that at
// most KeepTweets remain. Age is ordered by timestamp, with the snowflake
// breaking ties. If the channel already holds KeepTweets or fewer tweets,
// nothing is deleted.
//
// The work is done in a single DELETE statement. It is therefore atomic on
// its own and needs no explicit transaction, row iteration, or separate
// COUNT query.
func (db Database) PruneOldestTweets(channel string) error {
	// The subquery lists the channel's tweets newest first and skips the
	// KeepTweets newest ones. In SQLite, LIMIT -1 means "no upper bound",
	// so every remaining (older) row is selected for deletion.
	_, err := db.Exec(`DELETE FROM tweet WHERE tweet_id IN (
		SELECT tweet_id FROM tweet WHERE channel=$1
		ORDER BY timestamp DESC, snowflake DESC
		LIMIT -1 OFFSET $2)`, channel, KeepTweets)
	if err != nil {
		return fmt.Errorf("pruning tweets for channel %q: %w", channel, err)
	}
	return nil
}

// FromTweet builds a Tweet row for channel from a scraped tweet.
//
// TweetId is left zero, since tweet_id is AUTOINCREMENT in SqliteSchema.
// The tweet ID string must parse as an unsigned 64-bit integer; otherwise
// the strconv error is returned and the *Tweet is nil.
func FromTweet(channel string, tweet *ts.Tweet) (*Tweet, error) {
	id, parseErr := strconv.ParseUint(tweet.ID, 10, 64)
	if parseErr != nil {
		return nil, parseErr
	}
	return &Tweet{
		TweetId:   0,
		Snowflake: id,
		Channel:   channel,
		Timestamp: tweet.Timestamp,
	}, nil
}

// EqualsTweet reports whether t and the scraped tweet share the same
// snowflake. A scraped tweet whose ID does not parse as a uint64 never
// matches.
func (t Tweet) EqualsTweet(tweet *ts.Tweet) bool {
	other, err := strconv.ParseUint(tweet.ID, 10, 64)
	return err == nil && other == t.Snowflake
}

// Equals reports whether t and tweet refer to the same tweet, meaning they
// have the same snowflake. Other fields (row ID, channel, timestamp) are
// ignored. A nil tweet is never equal.
func (t Tweet) Equals(tweet *Tweet) bool {
	return tweet != nil && t.Snowflake == tweet.Snowflake
}