package localserver

import (
	"encoding/binary"
	"os"
	"path/filepath"
	"sync"

	"github.com/dgraph-io/badger/v3"
)

// NodeChatTrackingDB stores one uint64 "last index" per user for a single
// chat. The values live in a Badger database under
// <dataPath>/data/chats/<ChatID>/__tracking__, keyed by user ID and encoded
// big-endian.
//
// The database is not kept open between calls: every operation opens it, runs
// its transaction(s) and closes it again. Because of that, every method holds
// the exclusive lock for the whole open/close cycle, including read-only ones.
type NodeChatTrackingDB struct {
	ChatID string
	// db opens the Badger database, passes it to cb and closes it afterwards.
	db   func(cb func(*badger.DB) (err error)) (err error)
	lock *sync.RWMutex
}

// encodeTrackingIndex returns index in the on-disk value format: a bufferSize
// byte slice whose first 8 bytes hold index in big-endian order.
func encodeTrackingIndex(index uint64) []byte {
	b := make([]byte, bufferSize)
	binary.BigEndian.PutUint64(b, index)
	return b
}

// NewNodeChatTracking creates the tracking directory for chatId if needed and
// makes sure both initiator and target have a tracking entry. Entries that
// already exist are left untouched; missing ones are created with index 0.
//
// It returns an error if the directory cannot be created, the database cannot
// be opened, or the seeding transaction fails.
func NewNodeChatTracking(chatId, initiator, target string) (*NodeChatTrackingDB, error) {
	dir := filepath.Join(dataPath, "data", "chats", chatId, "__tracking__")
	// MkdirAll is a no-op when the directory already exists, so no separate
	// existence check is required.
	if err := os.MkdirAll(dir, 0700); err != nil {
		return nil, err
	}

	db := func(f func(*badger.DB) (err error)) (err error) {
		var bdb *badger.DB
		bdb, err = badger.Open(badger.DefaultOptions(dir).WithLogger(dbLogger))
		if err != nil {
			return err
		}
		defer func() {
			// Surface a failed close, but never mask the callback's own error.
			if cerr := bdb.Close(); err == nil {
				err = cerr
			}
		}()
		return f(bdb)
	}

	seed := func(d *badger.DB) error {
		return d.Update(func(txn *badger.Txn) error {
			zero := encodeTrackingIndex(0)
			for _, member := range [2]string{initiator, target} {
				switch _, err := txn.Get([]byte(member)); err {
				case nil:
					// Already tracked: keep the stored index.
				case badger.ErrKeyNotFound:
					if err := txn.Set([]byte(member), zero); err != nil {
						return err
					}
				default:
					return err
				}
			}
			return nil
		})
	}
	if err := db(seed); err != nil {
		return nil, err
	}

	return &NodeChatTrackingDB{chatId, db, new(sync.RWMutex)}, nil
}

// Initialize sets the last index of every given user to lastIndex.
//
// All users are written in a single transaction, so the database is opened
// once and either every user is updated or none is. With no users it does
// nothing and returns nil.
func (nctdb *NodeChatTrackingDB) Initialize(lastIndex uint64, users ...string) (err error) {
	if len(users) == 0 {
		return nil
	}

	nctdb.lock.Lock()
	defer nctdb.lock.Unlock()

	return nctdb.db(func(d *badger.DB) error {
		return d.Update(func(txn *badger.Txn) error {
			// The value is never modified after being handed to the
			// transaction, so one buffer can be shared by all entries.
			value := encodeTrackingIndex(lastIndex)
			for _, user := range users {
				if err := txn.Set([]byte(user), value); err != nil {
					return err
				}
			}
			return nil
		})
	})
}

// RevertTrackingLastIndex lowers every stored index that is greater than or
// equal to lastIndex down to lastIndex. Entries below lastIndex are left
// unchanged.
func (nctdb *NodeChatTrackingDB) RevertTrackingLastIndex(lastIndex uint64) (err error) {
	nctdb.lock.Lock()
	defer nctdb.lock.Unlock()

	return nctdb.db(func(d *badger.DB) error {
		// Phase 1: collect the keys whose index must be rewritten.
		var keys [][]byte
		collect := func(txn *badger.Txn) error {
			it := txn.NewIterator(badger.DefaultIteratorOptions)
			defer it.Close()

			for it.Rewind(); it.Valid(); it.Next() {
				item := it.Item()
				var stale bool
				if err := item.Value(func(val []byte) error {
					stale = binary.BigEndian.Uint64(val) >= lastIndex
					return nil
				}); err != nil {
					return err
				}
				if stale {
					// item.Key() is only valid until the iterator advances;
					// the key is used after the transaction, so copy it.
					keys = append(keys, item.KeyCopy(nil))
				}
			}
			return nil
		}
		if err := d.View(collect); err != nil {
			return err
		}
		if len(keys) == 0 {
			return nil
		}

		// Phase 2: rewrite the collected entries.
		return d.Update(func(txn *badger.Txn) error {
			value := encodeTrackingIndex(lastIndex)
			for _, key := range keys {
				if err := txn.Set(key, value); err != nil {
					return err
				}
			}
			return nil
		})
	})
}

// SetUserLastIndex stores lastIndex as the last index of userId, creating the
// entry if it does not exist.
func (nctdb *NodeChatTrackingDB) SetUserLastIndex(userId string, lastIndex uint64) (err error) {
	nctdb.lock.Lock()
	defer nctdb.lock.Unlock()

	return nctdb.db(func(d *badger.DB) error {
		return d.Update(func(txn *badger.Txn) error {
			return txn.Set([]byte(userId), encodeTrackingIndex(lastIndex))
		})
	})
}

// GetUserLastIndex returns the last index stored for userId.
//
// If the user has no entry the error from Badger (badger.ErrKeyNotFound) is
// returned and index is 0. The stored value is a uint64; it is converted to
// uint, which truncates on platforms where uint is 32 bits wide.
func (nctdb *NodeChatTrackingDB) GetUserLastIndex(userId string) (index uint, err error) {
	// The exclusive lock is intentional even though this is a read: each call
	// re-opens the database, and Badger takes a directory lock on Open, so
	// concurrent readers must not overlap.
	nctdb.lock.Lock()
	defer nctdb.lock.Unlock()

	err = nctdb.db(func(d *badger.DB) error {
		return d.View(func(txn *badger.Txn) error {
			item, err := txn.Get([]byte(userId))
			if err != nil {
				return err
			}
			return item.Value(func(val []byte) error {
				index = uint(binary.BigEndian.Uint64(val))
				return nil
			})
		})
	})
	return index, err
}

// DeleteUserTracking removes the tracking entry of userId.
func (nctdb *NodeChatTrackingDB) DeleteUserTracking(userId string) (err error) {
	nctdb.lock.Lock()
	defer nctdb.lock.Unlock()

	return nctdb.db(func(d *badger.DB) error {
		return d.Update(func(txn *badger.Txn) error {
			return txn.Delete([]byte(userId))
		})
	})
}