package localserver

import (
	"encoding/json"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"runtime"
	"sort"
	"strings"
	"sync/atomic"

	"github.com/pion/webrtc/v3"
)

// Event ids used for file transfer inside a call. This file handles
// UPLOAD_INIT, UPLOAD, UPLOAD_DONE and DOWNLOAD and emits UPLOAD_DONE and the
// HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE* messages; the remaining ids are
// presumably used elsewhere in the package (not referenced here).
const (
	UPLOAD_INIT                              = "upload_init"
	UPLOAD                                   = "upload"
	UPLOAD_DONE                              = "upload_done"
	DOWNLOAD_INIT                            = "download_init"
	DOWNLOAD                                 = "download"
	DOWNLOAD_DONE                            = "download_done"
	HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE_INIT = "hosted_squad_download_file_response_init"
	HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE      = "hosted_squad_download_file_response"
	HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE_END  = "hosted_squad_download_file_response_end"
)

const (
	// NOTE(review): not referenced in this file; download() currently sends
	// without any buffered-amount flow control.
	bufferedAmountLowThreshold uint64 = 512 * 1024
	//maxBufferedAmount uint64 = 1024 * 1024
)

// WebrtcCallFileManager stores files uploaded by call participants under
// data/squads/<squadId>/ and streams stored files back over data channels.
type WebrtcCallFileManager struct {
	// files holds the handles opened by initUpload, keyed by on-disk path.
	files map[string]*os.File
	// l is a CAS spin-lock word (0 = free, 1 = held) guarding files.
	l *int32
}

// NewWebrtcCallFileManager returns a manager with no uploads in progress.
func NewWebrtcCallFileManager() *WebrtcCallFileManager {
	return &WebrtcCallFileManager{
		files: make(map[string]*os.File),
		l:     new(int32),
	}
}

// HandleCallEvent processes one file-transfer event sent by peer `from` in
// squad squadId and returns once the event has been fully handled:
//
//   - UPLOAD_INIT: payload["filename"] (string); creates or truncates the file.
//   - UPLOAD: payload["filename"]; appends data to the open upload.
//   - UPLOAD_DONE: payload["filename"] and payload["targets"] (array of peer
//     ids); closes the upload and notifies every target's data channel.
//   - DOWNLOAD: payload["filename"] and payload["peerId"]; streams the stored
//     file to peerId's data channel.
//
// Unknown event ids are ignored and yield nil. While sending, each involved
// DataChannel's l spin lock is held until this method returns.
func (w *WebrtcCallFileManager) HandleCallEvent(from string, squadId string, eventId string, payload map[string]interface{}, data []byte, manager *WebRTCCallManager) error {
	logger.Println("got an event in call file manager", from, eventId, payload)
	switch eventId {
	case UPLOAD_INIT:
		fileName, err := callFilePayloadString(payload, "filename")
		if err != nil {
			return err
		}
		return w.initUpload(squadId, from, fileName)
	case UPLOAD:
		fileName, err := callFilePayloadString(payload, "filename")
		if err != nil {
			return err
		}
		return w.upload(squadId, from, fileName, data)
	case UPLOAD_DONE:
		fileName, err := callFilePayloadString(payload, "filename")
		if err != nil {
			return err
		}
		channels, err := callFileTargetChannels(payload, manager)
		if err != nil {
			return err
		}
		for _, channel := range channels {
			callFileAcquire(channel.l)
			// Deferred on purpose: every target stays locked until uploadDone
			// has finished sending to all of them.
			defer callFileRelease(channel.l)
		}
		return w.uploadDone(squadId, from, fileName, channels)
	case DOWNLOAD:
		fileName, err := callFilePayloadString(payload, "filename")
		if err != nil {
			return err
		}
		peerId, err := callFilePayloadString(payload, "peerId")
		if err != nil {
			return err
		}
		manager.DataChannelMapMux.RLock()
		channel, ok := manager.DataChannels[peerId]
		manager.DataChannelMapMux.RUnlock()
		if !ok {
			return fmt.Errorf("no corresponding datachannel : %s", peerId)
		}
		// The map lock is released before spinning so a writer to the map
		// cannot be blocked behind a busy data channel.
		callFileAcquire(channel.l)
		defer callFileRelease(channel.l)
		return w.download(squadId, from, fileName, channel.DataChannel)
	}
	return nil
}

// callFilePayloadString returns payload[key] as a string, or an error that
// says whether the field is missing or has the wrong type.
func callFilePayloadString(payload map[string]interface{}, key string) (string, error) {
	raw, ok := payload[key]
	if !ok {
		return "", fmt.Errorf("no field %s in payload", key)
	}
	s, ok := raw.(string)
	if !ok {
		return "", fmt.Errorf("field %s in payload is not a string", key)
	}
	return s, nil
}

// callFileTargetChannels resolves payload["targets"] (an array of peer ids)
// to the matching data channels. Duplicate ids are dropped, because locking
// the same channel twice would spin forever. Ids are sorted so concurrent
// callers always acquire channel locks in the same order.
func callFileTargetChannels(payload map[string]interface{}, manager *WebRTCCallManager) ([]*DataChannel, error) {
	raw, ok := payload["targets"]
	if !ok {
		return nil, fmt.Errorf("no field targets in payload")
	}
	list, ok := raw.([]interface{})
	if !ok {
		return nil, fmt.Errorf("field targets in payload is not an array")
	}
	ids := make([]string, 0, len(list))
	seen := make(map[string]struct{}, len(list))
	for _, target := range list {
		id, ok := target.(string)
		if !ok {
			return nil, fmt.Errorf("target %v in payload is not a string", target)
		}
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		ids = append(ids, id)
	}
	sort.Strings(ids)

	manager.DataChannelMapMux.RLock()
	defer manager.DataChannelMapMux.RUnlock()
	channels := make([]*DataChannel, 0, len(ids))
	for _, id := range ids {
		channel, ok := manager.DataChannels[id]
		if !ok {
			return nil, fmt.Errorf("no corresponding datachannel : %s", id)
		}
		channels = append(channels, channel)
	}
	return channels, nil
}

// callFileAcquire spins until it flips *l from 0 to 1. It yields the
// processor between attempts instead of busy-burning a CPU.
func callFileAcquire(l *int32) {
	for !atomic.CompareAndSwapInt32(l, 0, 1) {
		runtime.Gosched()
	}
}

// callFileRelease releases a lock word taken with callFileAcquire.
func callFileRelease(l *int32) {
	atomic.StoreInt32(l, 0)
}

// callFilePath returns the on-disk location of fileName inside squadId's
// storage directory (data/squads/<squadId>). fileName is taken from the event
// payload, so any name that resolves outside that directory (e.g. "../x") or
// to the directory itself is rejected.
func callFilePath(squadId string, fileName string) (string, error) {
	base := filepath.Join("data", "squads", squadId)
	path := filepath.Join(base, fileName) // Join also Cleans, collapsing ".."
	if !strings.HasPrefix(path, base+string(filepath.Separator)) {
		return "", fmt.Errorf("invalid file name %q", fileName)
	}
	return path, nil
}

// callFileMessage encodes the {"type", "from": "server", "payload"} envelope
// used for every message this file sends.
func callFileMessage(msgType string, payload map[string]interface{}) (string, error) {
	bs, err := json.Marshal(map[string]interface{}{
		"type":    msgType,
		"from":    "server",
		"payload": payload,
	})
	if err != nil {
		return "", fmt.Errorf("encoding %s message: %w", msgType, err)
	}
	return string(bs), nil
}

// callFileSend encodes one envelope and sends it as a text message.
func callFileSend(channel *webrtc.DataChannel, msgType string, payload map[string]interface{}) error {
	msg, err := callFileMessage(msgType, payload)
	if err != nil {
		return err
	}
	return channel.SendText(msg)
}

// initUpload creates (or truncates) the target file and keeps it open for
// appending by subsequent upload calls. A still-open handle for the same
// file is closed first so restarted uploads do not leak descriptors.
func (w *WebrtcCallFileManager) initUpload(squadId string, from string, fileName string) error {
	path, err := callFilePath(squadId, fileName)
	if err != nil {
		return err
	}
	callFileAcquire(w.l)
	defer callFileRelease(w.l)

	if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
		return fmt.Errorf("creating directory for %s: %w", fileName, err)
	}
	if prev, ok := w.files[path]; ok {
		_ = prev.Close() // best effort: the old upload is being abandoned
		delete(w.files, path)
	}
	f, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		return fmt.Errorf("creating %s: %w", fileName, err)
	}
	w.files[path] = f
	return nil
}

// upload appends data to the file opened by initUpload. The manager lock is
// held during the write so chunks are appended one at a time.
func (w *WebrtcCallFileManager) upload(squadId string, from string, fileName string, data []byte) error {
	path, err := callFilePath(squadId, fileName)
	if err != nil {
		return err
	}
	callFileAcquire(w.l)
	defer callFileRelease(w.l)

	f, ok := w.files[path]
	if !ok {
		return fmt.Errorf("no open file with name %s", fileName)
	}
	if _, err := f.Write(data); err != nil {
		return fmt.Errorf("writing %s: %w", fileName, err)
	}
	return nil
}

// uploadDone closes the upload and sends an UPLOAD_DONE message to every
// channel. A close failure is returned and no one is notified. Every channel
// is attempted even if an earlier send fails; the first send error is
// returned.
func (w *WebrtcCallFileManager) uploadDone(squadId string, from string, fileName string, channels []*DataChannel) error {
	path, err := callFilePath(squadId, fileName)
	if err != nil {
		return err
	}

	// Only the map bookkeeping needs the lock; network sends happen outside it.
	callFileAcquire(w.l)
	f, ok := w.files[path]
	delete(w.files, path)
	callFileRelease(w.l)

	if !ok {
		return fmt.Errorf("no open file with name %s", fileName)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("closing %s: %w", fileName, err)
	}

	msg, err := callFileMessage(UPLOAD_DONE, map[string]interface{}{"path": fileName})
	if err != nil {
		return err
	}
	var firstErr error
	for _, channel := range channels {
		if err := channel.DataChannel.SendText(msg); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}

// download streams a stored file to channel. It sends a
// HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE_INIT message first, then the file in
// chunks as HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE messages, and finally a
// HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE_END message. Chunk bytes are []byte, which encoding/json emits
// as base64. A missing file or a read error is returned instead of being
// logged or aborting the process.
func (w *WebrtcCallFileManager) download(squadId string, dst string, fileName string, channel *webrtc.DataChannel) error {
	const chunkSize = 30000

	path, err := callFilePath(squadId, fileName)
	if err != nil {
		return err
	}
	f, err := os.Open(path)
	if err != nil {
		return fmt.Errorf("opening %s: %w", path, err)
	}
	defer f.Close()

	if err := callFileSend(channel, HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE_INIT, map[string]interface{}{"path": fileName}); err != nil {
		return err
	}

	buf := make([]byte, chunkSize)
	for {
		n, readErr := f.Read(buf)
		// Per the io.Reader contract, bytes read must be handled before the error.
		if n > 0 {
			// buf is reused next iteration; safe because it is marshalled here.
			if err := callFileSend(channel, HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE, map[string]interface{}{
				"path":    fileName,
				"content": buf[:n],
			}); err != nil {
				return err
			}
		}
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return fmt.Errorf("reading %s: %w", path, readErr)
		}
	}

	return callFileSend(channel, HOSTED_SQUAD_DOWNLOAD_FILE_RESPONSE_END, map[string]interface{}{"path": fileName})
}