Zippytal-Node/webrtcCallFileManager.go

312 lines
8.3 KiB
Go

package localserver
import (
"bufio"
"encoding/json"
"fmt"
"io"
"log"
"os"
"path/filepath"
"sync/atomic"
"github.com/pion/webrtc/v3"
)
// Identifiers of file-transfer events exchanged with call peers. The upload
// and download ids are matched against the eventId given to HandleCallEvent;
// UPLOAD_DONE and the HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE* values are also used
// as the "type" field of the JSON messages sent back over data channels.
const (
	UPLOAD_INIT = "upload_init"
	UPLOAD      = "upload"
	UPLOAD_DONE = "upload_done"

	DOWNLOAD_INIT = "download_init"
	DOWNLOAD      = "download"
	DOWNLOAD_DONE = "download_done"

	HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE_INIT = "hosted_squad_download_file_response_init"
	HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE      = "hosted_squad_download_file_response"
	HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE_END  = "hosted_squad_download_file_response_end"

	// bufferedAmountLowThreshold is 512 KiB. A companion upper bound of
	// 1024 * 1024 (maxBufferedAmount) was considered and left disabled.
	bufferedAmountLowThreshold uint64 = 512 * 1024
)
// WebrtcCallFileManager receives files uploaded by call peers into the squad
// data directory and streams stored files back to peers on request.
type WebrtcCallFileManager struct {
	// files holds the uploads currently in progress, keyed by file name; each
	// entry is opened by initUpload and closed/removed by uploadDone.
	files map[string]*os.File
	// l is a lock word guarding files: 0 means free, 1 means held. The methods
	// acquire it with atomic compare-and-swap and release it by writing 0.
	l *uint32
}

// NewWebrtcCallFileManager returns a manager with no uploads in progress and
// its lock word in the released (0) state.
func NewWebrtcCallFileManager() *WebrtcCallFileManager {
	manager := &WebrtcCallFileManager{
		files: map[string]*os.File{},
		l:     new(uint32),
	}
	return manager
}
// HandleCallEvent handles one file-transfer event sent by peer `from` during a
// call in squad squadId and returns once the event has been fully processed:
//
//   - UPLOAD_INIT: payload["filename"] names a file to (re)create; see initUpload.
//   - UPLOAD: data is appended to the upload named by payload["filename"].
//   - UPLOAD_DONE: the upload named by payload["filename"] is finished and every
//     peer id listed in payload["targets"] (a JSON array) is notified on its
//     data channel.
//   - DOWNLOAD: the file payload["filename"] is streamed to the data channel of
//     payload["peerId"].
//
// Unknown event ids are ignored and yield a nil error. The data channels used
// for sending are held locked (through their l lock word) until the send is
// over.
//
// Processing is synchronous: the caller waits for the outcome anyway, and
// running inline guarantees every channel lock is released before returning.
func (w *WebrtcCallFileManager) HandleCallEvent(from string, squadId string, eventId string, payload map[string]interface{}, data []byte, manager *WebRTCCallManager) (err error) {
	logger.Println("got an event in call file manager", from, eventId, payload)
	var fileName string
	switch eventId {
	case UPLOAD_INIT:
		if fileName, err = w.callEventPayloadString(payload, "filename"); err != nil {
			return err
		}
		return w.initUpload(squadId, from, fileName)
	case UPLOAD:
		if fileName, err = w.callEventPayloadString(payload, "filename"); err != nil {
			return err
		}
		return w.upload(squadId, from, fileName, data)
	case UPLOAD_DONE:
		if fileName, err = w.callEventPayloadString(payload, "filename"); err != nil {
			return err
		}
		rawTargets, ok := payload["targets"]
		if !ok {
			return fmt.Errorf("no field targets in payload")
		}
		targets, ok := rawTargets.([]interface{})
		if !ok {
			return fmt.Errorf("field targets in payload is not an array")
		}
		channels, resolveErr := w.resolveCallTargets(manager, targets)
		if resolveErr != nil {
			return resolveErr
		}
		unlock := w.lockCallDataChannels(channels)
		defer unlock()
		return w.uploadDone(squadId, from, fileName, channels)
	case DOWNLOAD:
		if fileName, err = w.callEventPayloadString(payload, "filename"); err != nil {
			return err
		}
		var peerId string
		if peerId, err = w.callEventPayloadString(payload, "peerId"); err != nil {
			return err
		}
		// Only the map lookup needs the map lock; waiting for the channel's own
		// lock while holding it could block writers of the map indefinitely.
		manager.DataChannelMapMux.RLock()
		channel, ok := manager.DataChannels[peerId]
		manager.DataChannelMapMux.RUnlock()
		if !ok {
			return fmt.Errorf("no corresponding datachannel")
		}
		unlock := w.lockCallDataChannels([]*DataChannel{channel})
		defer unlock()
		return w.download(squadId, from, fileName, channel.DataChannel)
	}
	return nil
}

// callEventPayloadString returns payload[key] when it is present and a string.
func (w *WebrtcCallFileManager) callEventPayloadString(payload map[string]interface{}, key string) (string, error) {
	raw, ok := payload[key]
	if !ok {
		return "", fmt.Errorf("no field %s in payload", key)
	}
	value, ok := raw.(string)
	if !ok {
		return "", fmt.Errorf("field %s in payload is not a string", key)
	}
	return value, nil
}

// resolveCallTargets maps each peer id in targets to its data channel under the
// manager's map read lock, dropping duplicate ids (channel locks are not
// reentrant, so locking the same channel twice would spin forever).
func (w *WebrtcCallFileManager) resolveCallTargets(manager *WebRTCCallManager, targets []interface{}) ([]*DataChannel, error) {
	manager.DataChannelMapMux.RLock()
	defer manager.DataChannelMapMux.RUnlock()
	seen := make(map[string]struct{}, len(targets))
	channels := make([]*DataChannel, 0, len(targets))
	for _, target := range targets {
		peerId, ok := target.(string)
		if !ok {
			return nil, fmt.Errorf("target %v in payload is not a string", target)
		}
		if _, dup := seen[peerId]; dup {
			continue
		}
		seen[peerId] = struct{}{}
		channel, ok := manager.DataChannels[peerId]
		if !ok {
			return nil, fmt.Errorf("no corresponding datachannel : %s", peerId)
		}
		channels = append(channels, channel)
	}
	return channels, nil
}

// lockCallDataChannels spins until it holds the lock word of every channel in
// channels (which must not contain duplicates) and returns a function that
// releases them all. Locks are taken all-or-nothing: when one channel is busy,
// those already taken are released and the attempt restarts, so callers
// locking overlapping sets in different orders do not deadlock each other.
func (w *WebrtcCallFileManager) lockCallDataChannels(channels []*DataChannel) (unlock func()) {
	for {
		held := 0
		for _, channel := range channels {
			if !atomic.CompareAndSwapUint32(channel.l, 0, 1) {
				break
			}
			held++
		}
		if held == len(channels) {
			return func() {
				for _, channel := range channels {
					atomic.StoreUint32(channel.l, 0)
				}
			}
		}
		for _, channel := range channels[:held] {
			atomic.StoreUint32(channel.l, 0)
		}
	}
}
// initUpload starts receiving fileName for squad squadId. It ensures the squad
// directory exists, creates (or truncates) the target file, and keeps it open
// for appending in w.files until uploadDone closes it. from is unused.
//
// fileName comes from a remote peer's payload, so only a plain base name is
// accepted; anything with a path separator, "." or ".." is rejected so the
// upload cannot escape the squad directory.
//
// NOTE(review): w.files is keyed by fileName alone, so two squads uploading the
// same name at once share one entry — confirm whether that can happen.
func (w *WebrtcCallFileManager) initUpload(squadId string, from string, fileName string) (err error) {
	if fileName == "" || fileName == "." || fileName == ".." || filepath.Base(fileName) != fileName {
		return fmt.Errorf("invalid file name %q", fileName)
	}
	// NOTE(review): dataPath is joined twice, as in download; kept so uploads
	// and downloads keep resolving to the same directory — confirm intent.
	squadDir := filepath.Join(dataPath, dataPath, "data", "squads", squadId)

	for !atomic.CompareAndSwapUint32(w.l, 0, 1) {
		// busy-wait: the lock word is shared with the other methods of w
	}
	defer atomic.StoreUint32(w.l, 0)

	if err = os.MkdirAll(squadDir, 0700); err != nil {
		return fmt.Errorf("creating squad directory %s: %w", squadDir, err)
	}
	if prev, ok := w.files[fileName]; ok {
		// A re-initialised upload replaces the previous one; close the old handle
		// so it is not leaked. Its close error is irrelevant since the file is
		// truncated right below.
		_ = prev.Close()
		delete(w.files, fileName)
	}
	// Same flags os.Create uses (create/truncate, 0666 before umask), plus
	// append mode for the chunks written by upload.
	f, err := os.OpenFile(filepath.Join(squadDir, fileName), os.O_CREATE|os.O_TRUNC|os.O_WRONLY|os.O_APPEND, 0666)
	if err != nil {
		return fmt.Errorf("creating upload file %s: %w", fileName, err)
	}
	w.files[fileName] = f
	return nil
}
// upload appends one chunk of data to the file that initUpload opened for
// fileName. squadId and from are unused. It fails when no upload is in
// progress for fileName; otherwise it returns the write error, if any.
func (w *WebrtcCallFileManager) upload(squadId string, from string, fileName string, data []byte) (err error) {
	for !atomic.CompareAndSwapUint32(w.l, 0, 1) {
		// busy-wait: the lock word is shared with the other methods of w
	}
	defer atomic.SwapUint32(w.l, 0)

	target, open := w.files[fileName]
	if !open {
		return fmt.Errorf("no open file with name %s", fileName)
	}
	_, err = target.Write(data)
	return err
}
// uploadDone finishes the upload of fileName: it removes the file from
// w.files, closes it, and then sends an UPLOAD_DONE message
// ({"type":"upload_done","from":"server","payload":{"path":fileName}}) on every
// channel in channels. squadId and from are unused.
//
// The manager lock is held only while the map is updated, not during the
// network sends. A close failure is returned without notifying anyone, since
// the file may be incomplete on disk. Every channel is attempted even when one
// send fails; the first send error is returned.
func (w *WebrtcCallFileManager) uploadDone(squadId string, from string, fileName string, channels []*DataChannel) (err error) {
	for !atomic.CompareAndSwapUint32(w.l, 0, 1) {
		// busy-wait: the lock word is shared with the other methods of w
	}
	f, ok := w.files[fileName]
	delete(w.files, fileName)
	atomic.StoreUint32(w.l, 0)

	if !ok {
		return fmt.Errorf("no open file with name %s", fileName)
	}
	if err = f.Close(); err != nil {
		return fmt.Errorf("closing uploaded file %s: %w", fileName, err)
	}

	bsInit, err := json.Marshal(map[string]interface{}{
		"type": UPLOAD_DONE,
		"from": "server",
		"payload": map[string]string{
			"path": fileName,
		},
	})
	if err != nil {
		return err
	}
	msg := string(bsInit)
	for _, channel := range channels {
		if sendErr := channel.DataChannel.SendText(msg); sendErr != nil && err == nil {
			err = fmt.Errorf("notifying upload of %s: %w", fileName, sendErr)
		}
	}
	return err
}
// download streams the stored file fileName of squad squadId over channel as
// JSON text messages, each shaped {"type":..., "from":"server", "payload":...}:
// one HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE_INIT, one
// HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE per chunk of up to 30000 bytes (the chunk
// is in payload.content, which encoding/json encodes as base64), then
// HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE_END. dst is unused.
//
// fileName comes from a remote peer's payload, so only a plain base name is
// accepted. A missing file or a read failure is returned as an error. After a
// mid-stream failure the peer has received INIT and some chunks but no END.
//
// NOTE(review): chunks are sent without back-pressure (bufferedAmountLowThreshold
// is not used here), so presumably a large file is queued in full on the data
// channel — consider BufferedAmount/OnBufferedAmountLow flow control.
func (w *WebrtcCallFileManager) download(squadId string, dst string, fileName string, channel *webrtc.DataChannel) (err error) {
	if fileName == "" || fileName == "." || fileName == ".." || filepath.Base(fileName) != fileName {
		return fmt.Errorf("invalid file name %q", fileName)
	}
	// NOTE(review): dataPath is joined twice, matching initUpload — confirm intent.
	path := filepath.Join(dataPath, dataPath, "data", "squads", squadId, fileName)
	f, err := os.Open(path)
	if err != nil {
		return fmt.Errorf("opening %s for download: %w", path, err)
	}
	defer f.Close()

	send := func(msgType string, payload map[string]interface{}) error {
		bs, err := json.Marshal(map[string]interface{}{
			"type":    msgType,
			"from":    "server",
			"payload": payload,
		})
		if err != nil {
			return err
		}
		return channel.SendText(string(bs))
	}

	if err = send(HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE_INIT, map[string]interface{}{"path": fileName}); err != nil {
		return err
	}

	r := bufio.NewReader(f)
	buf := make([]byte, 30000)
	logger.Println("start reading")
	for {
		n, readErr := r.Read(buf)
		// Per the io.Reader contract, bytes read before an error are still valid,
		// so forward them first. json.Marshal copies buf, so reusing it is safe.
		if n > 0 {
			chunk := map[string]interface{}{"path": fileName, "content": buf[:n]}
			if err = send(HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE, chunk); err != nil {
				return err
			}
		}
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			// Do not send END here: that would present a truncated file as complete.
			log.Printf("download of %s aborted mid-stream: %v", path, readErr)
			return fmt.Errorf("reading %s: %w", path, readErr)
		}
	}
	logger.Println("stop reading")

	return send(HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE_END, map[string]interface{}{"path": fileName})
}