diff --git a/extension/frame.go b/extension/frame.go new file mode 100644 index 0000000..dc1bd38 --- /dev/null +++ b/extension/frame.go @@ -0,0 +1,371 @@ +package extension + +import ( + "sync" + "time" +) + +// ==================== 媒体类型定义 ==================== + +type TrackType int + +const ( + TrackInvalid TrackType = iota - 1 + TrackVideo + TrackAudio + TrackTitle + TrackApplication +) + +func (t TrackType) String() string { + switch t { + case TrackVideo: + return "video" + case TrackAudio: + return "audio" + case TrackApplication: + return "application" + default: + return "invalid" + } +} + +func GetTrackType(str string) TrackType { + switch str { + case "video": + return TrackVideo + case "audio": + return TrackAudio + case "application": + return TrackApplication + default: + return TrackInvalid + } +} + +// ==================== 编解码标识 ==================== + +type CodecID int + +const ( + CodecInvalid CodecID = iota - 1 + CodecH264 + CodecH265 + CodecAAC + CodecG711A + CodecG711U + CodecOpus + // ... 其他编解码器定义 +) + +var codecNames = map[CodecID]string{ + CodecH264: "H264", + CodecH265: "H265", + CodecAAC: "mpeg4-generic", + CodecG711A: "PCMA", + CodecG711U: "PCMU", + CodecOpus: "opus", +} + +func (c CodecID) String() string { + if name, ok := codecNames[c]; ok { + return name + } + return "invalid" +} + +func GetCodecID(name string) CodecID { + for cid, n := range codecNames { + if n == name { + return cid + } + } + return CodecInvalid +} + +// ==================== 帧接口定义 ==================== + +type Frame interface { + DTS() uint64 + PTS() uint64 + PrefixSize() int + KeyFrame() bool + ConfigFrame() bool + CacheAble() bool + DropAble() bool + DecodeAble() bool + Data() []byte + Size() int + CodecID() CodecID + TrackType() TrackType +} + +// ==================== 基础帧实现 ==================== + +type BaseFrame struct { + Dts uint64 + Pts uint64 + VPrefixSize int + Codec CodecID + Track TrackType +} + +func (f *BaseFrame) DTS() uint64 { return f.Dts } +func (f *BaseFrame) PTS() uint64 { return f.Pts } +func (f *BaseFrame) PrefixSize() int { return f.VPrefixSize } +func (f *BaseFrame) CodecID() CodecID { return f.Codec } +func (f *BaseFrame) TrackType() TrackType { return f.Track } + +// ==================== 帧包装器 ==================== + +type FrameFromBytes struct { + *BaseFrame + DataPtr []byte + IsKey bool + IsConfig bool +} + +func (f *FrameFromBytes) KeyFrame() bool { return f.IsKey } +func (f *FrameFromBytes) ConfigFrame() bool { return f.IsConfig } +func (f *FrameFromBytes) CacheAble() bool { return false } +func (f *FrameFromBytes) DropAble() bool { return false } +func (f *FrameFromBytes) DecodeAble() bool { + if f.TrackType() != TrackVideo { + return true + } + return !f.ConfigFrame() +} +func (f *FrameFromBytes) Data() []byte { return f.DataPtr } +func (f *FrameFromBytes) Size() int { return len(f.DataPtr) } + +// ==================== 可缓存帧 ==================== + +type FrameCacheAble struct { + *FrameFromBytes + Buffer []byte +} + +func NewFrameCacheAble(frame Frame, forceKeyFrame bool) Frame { + if frame.CacheAble() { + return frame + } + + buffer := make([]byte, frame.Size()) + copy(buffer, frame.Data()) + + return &FrameCacheAble{ + FrameFromBytes: &FrameFromBytes{ + BaseFrame: &BaseFrame{ + Dts: frame.DTS(), + Pts: frame.PTS(), + VPrefixSize: frame.PrefixSize(), + Codec: frame.CodecID(), + Track: frame.TrackType(), + }, + DataPtr: buffer, + IsKey: forceKeyFrame || frame.KeyFrame(), + IsConfig: frame.ConfigFrame(), + }, + Buffer: buffer, + } +} + +func (f *FrameCacheAble) CacheAble() bool { return true } + +// ==================== 时间戳修正 ==================== + +type FrameStamp struct { + Frame + Dts int64 + Pts int64 +} + +func NewFrameStamp(frame Frame, dts, pts int64) Frame { + return &FrameStamp{Frame: frame, Dts: dts, Pts: pts} +} + +func (f *FrameStamp) DTS() uint64 { return uint64(f.Dts) } +func (f *FrameStamp) PTS() uint64 { return uint64(f.Pts) } + +// ==================== 帧合并器 ==================== + +type FrameMerger struct { + frameCache []Frame + haveDecodeable bool + mergeType int +} + +const ( + MergeNone = iota + MergeH264Prefix + MergeMP4NalSize +) + +func NewFrameMerger(mergeType int) *FrameMerger { + return &FrameMerger{mergeType: mergeType} +} + +func (m *FrameMerger) WillFlush(frame Frame) bool { + if len(m.frameCache) == 0 { + return false + } + if frame == nil { + return true + } + + last := m.frameCache[len(m.frameCache)-1] + switch m.mergeType { + case MergeNone: + return last.DTS() != frame.DTS() || len(m.frameCache) > 100 + case MergeH264Prefix, MergeMP4NalSize: + return last.DTS() != frame.DTS() || frame.DecodeAble() || frame.ConfigFrame() || len(m.frameCache) > 100 + default: + return true + } +} + +func (m *FrameMerger) DoMerge(buffer *[]byte, frame Frame) { + switch m.mergeType { + case MergeNone: + *buffer = append(*buffer, frame.Data()...) + case MergeH264Prefix: + if frame.PrefixSize() == 0 { + *buffer = append(*buffer, 0, 0, 0, 1) + } + // 确保每个NAL单元都有独立的起始码 + *buffer = append(*buffer, frame.Data()...) + case MergeMP4NalSize: + size := uint32(frame.Size() - frame.PrefixSize()) + *buffer = append(*buffer, byte(size>>24), byte(size>>16), byte(size>>8), byte(size)) + *buffer = append(*buffer, frame.Data()[frame.PrefixSize():]...) + } +} + +func (m *FrameMerger) InputFrame(frame Frame, cb func(dts, pts uint64, buffer []byte, haveKeyFrame bool)) bool { + if frame == nil { + return false + } + + if frame.DecodeAble() { + m.haveDecodeable = true + } + m.frameCache = append(m.frameCache, frame) + + if frame != nil && !m.NeedMerge(frame.CodecID()) { + cb(frame.DTS(), frame.PTS(), frame.Data(), frame.KeyFrame()) + return true + } + + if m.WillFlush(frame) { + last := m.frameCache[len(m.frameCache)-1] + var buffer []byte + var haveKeyFrame = last.KeyFrame() + + if len(m.frameCache) > 1 || m.mergeType == MergeMP4NalSize { + buffer = make([]byte, 0, last.Size()+1024) + for _, f := range m.frameCache { + m.DoMerge(&buffer, f) + if f.KeyFrame() { + haveKeyFrame = true + } + } + cb(last.DTS(), last.PTS(), buffer, haveKeyFrame) + m.frameCache = m.frameCache[:0] + m.haveDecodeable = false + } + + } + + return true +} + +func (m *FrameMerger) NeedMerge(codec CodecID) bool { + return codec == CodecH264 || codec == CodecH265 +} + +func (m *FrameMerger) Flush(cb func(dts, pts uint64, buffer []byte, haveKeyFrame bool)) { + if len(m.frameCache) > 0 { + m.InputFrame(nil, cb) + } + m.frameCache = m.frameCache[:0] + m.haveDecodeable = false +} + +// ==================== 帧分发器 ==================== + +type FrameDispatcher struct { + delegates map[*FrameWriter]bool + mtx sync.Mutex + // 统计信息 + videoKeyFrames uint64 + frames uint64 + lastFrames uint64 + gopSize uint64 + gopInterval time.Duration + lastKeyFrameTS uint64 +} + +func NewFrameDispatcher() *FrameDispatcher { + return &FrameDispatcher{delegates: make(map[*FrameWriter]bool)} +} + +func (d *FrameDispatcher) AddDelegate(w *FrameWriter) { + d.mtx.Lock() + defer d.mtx.Unlock() + d.delegates[w] = true +} + +func (d *FrameDispatcher) DelDelegate(w *FrameWriter) { + d.mtx.Lock() + defer d.mtx.Unlock() + delete(d.delegates, w) +} + +func (d *FrameDispatcher) InputFrame(frame Frame) bool { + if !frame.ConfigFrame() && !frame.DropAble() { + d.frames++ + if frame.KeyFrame() && frame.TrackType() == TrackVideo { + d.videoKeyFrames++ + d.gopSize = d.frames - d.lastFrames + d.gopInterval = time.Duration(frame.DTS()-d.lastKeyFrameTS) * time.Millisecond + d.lastFrames = d.frames + d.lastKeyFrameTS = frame.DTS() + } + } + + d.mtx.Lock() + defer d.mtx.Unlock() + + var ret bool + for w := range d.delegates { + if w.InputFrame(frame) { + ret = true + } + } + return ret +} + +func (d *FrameDispatcher) GetStats() (uint64, uint64, uint64, time.Duration) { + return d.frames, d.videoKeyFrames, d.gopSize, d.gopInterval +} + +// Size 返回当前注册的代理数量 +// Size returns the number of registered delegates +func (d *FrameDispatcher) Size() int { + d.mtx.Lock() + defer d.mtx.Unlock() + return len(d.delegates) +} + +// ==================== 帧写入接口 ==================== + +type FrameWriter struct { + OnFrame func(Frame) bool +} + +func (w *FrameWriter) InputFrame(frame Frame) bool { + if w.OnFrame != nil { + return w.OnFrame(frame) + } + return false +} diff --git a/extension/frame_test.go b/extension/frame_test.go new file mode 100644 index 0000000..09be02e --- /dev/null +++ b/extension/frame_test.go @@ -0,0 +1,183 @@ +package extension + +import ( + "sync" + "testing" + "time" +) + +func TestTrackTypeConversion(t *testing.T) { + // 验证字符串转TrackType + if GetTrackType("video") != TrackVideo { + t.Error("video track type conversion failed") + } + if GetTrackType("invalid") != TrackInvalid { + t.Error("invalid track type conversion failed") + } + + // 验证TrackType转字符串 + if TrackVideo.String() != "video" { + t.Error("TrackVideo string conversion failed") + } +} + +func TestCodecIDConversion(t *testing.T) { + // 验证名称转CodecID + if GetCodecID("H264") != CodecH264 { + t.Error("H264 codec conversion failed") + } + if GetCodecID("invalid") != CodecInvalid { + t.Error("invalid codec conversion failed") + } + + // 验证CodecID转名称 + if CodecH264.String() != "H264" { + t.Error("CodecH264 string conversion failed") + } +} + +func TestBaseFrame(t *testing.T) { + frame := &BaseFrame{ + Dts: 1000, + Pts: 1005, + VPrefixSize: 4, + Codec: CodecH264, + Track: TrackVideo, + } + + if frame.DTS() != 1000 { + t.Error("DTS mismatch") + } + if frame.PrefixSize() != 4 { + t.Error("PrefixSize mismatch") + } +} + +func TestFrameCacheAble(t *testing.T) { + // 测试原始帧不可缓存 + raw := &FrameFromBytes{ + BaseFrame: &BaseFrame{Dts: 1000}, + DataPtr: []byte{1, 2, 3}, + } + if raw.CacheAble() { + t.Error("raw frame should not be cacheable") + } + + // 测试转换后可缓存 + cached := NewFrameCacheAble(raw, false) + if !cached.CacheAble() { + t.Error("cached frame should be cacheable") + } + if cached.Size() != 3 { + t.Error("data size mismatch") + } +} + +func TestFrameMerger_H264Prefix(t *testing.T) { + merger := NewFrameMerger(MergeH264Prefix) + var output []byte + + // 创建测试帧(无前缀) + frame1 := &FrameFromBytes{BaseFrame: &BaseFrame{Dts: 1000, VPrefixSize: 0}, DataPtr: []byte{0x67}} + frame2 := &FrameFromBytes{BaseFrame: &BaseFrame{Dts: 1000, VPrefixSize: 0}, DataPtr: []byte{0x68}} + + // 合并相同时间戳的帧 + merger.InputFrame(frame1, func(dts, pts uint64, buf []byte, key bool) { + output = append(output, buf...) + }) + merger.InputFrame(frame2, func(dts, pts uint64, buf []byte, key bool) { + output = append(output, buf...) + }) + merger.Flush(func(dts, pts uint64, buf []byte, key bool) { + output = append(output, buf...) + }) + + // 验证H264前缀添加 + expected := []byte{0, 0, 0, 1, 0x67, 0, 0, 0, 1, 0x68} + if string(output) != string(expected) { + t.Error("H264 prefix merge failed") + } +} + +func TestFrameMerger_MP4NalSize(t *testing.T) { + merger := NewFrameMerger(MergeMP4NalSize) + var output []byte + + // 创建测试帧(带4字节前缀) + frame := &FrameFromBytes{ + BaseFrame: &BaseFrame{VPrefixSize: 4, Dts: 1000}, + DataPtr: []byte{0, 0, 0, 1, 0x67}, + } + + merger.InputFrame(frame, func(dts, pts uint64, buf []byte, key bool) { + output = buf + }) + merger.Flush(func(dts, pts uint64, buf []byte, key bool) { + output = buf + }) + + // 验证MP4 NALU大小(4字节长度 + 原始数据) + if len(output) != 5 { + t.Error("MP4 nal size length mismatch") + } + if output[0] != 0 || output[1] != 0 || output[2] != 0 || output[3] != 1 { + t.Error("MP4 nal size header mismatch") + } +} + +func TestFrameDispatcher(t *testing.T) { + dispatcher := NewFrameDispatcher() + statsChan := make(chan struct{}) + + // 添加测试代理 + dispatcher.AddDelegate(&FrameWriter{ + OnFrame: func(frame Frame) bool { + if frame.KeyFrame() && frame.TrackType() == TrackVideo { + close(statsChan) + } + return true + }, + }) + + // 输入视频关键帧 + dispatcher.InputFrame(&FrameFromBytes{ + BaseFrame: &BaseFrame{ + Dts: 1000, + Track: TrackVideo, + }, + IsKey: true, + }) + + // 验证GOP统计 + select { + case <-statsChan: + frames, keys, gopSize, _ := dispatcher.GetStats() + if frames != 1 || keys != 1 || gopSize != 1 { + t.Error("GOP statistics mismatch") + } + case <-time.After(100 * time.Millisecond): + t.Error("timeout waiting for stats") + } +} + +func TestFrameDispatcher_Concurrent(t *testing.T) { + dispatcher := NewFrameDispatcher() + var done sync.WaitGroup + + // 并发添加/删除代理 + for i := 0; i < 10; i++ { + done.Add(1) + go func(id int) { + defer done.Done() + writer := &FrameWriter{} + dispatcher.AddDelegate(writer) + dispatcher.DelDelegate(writer) + }(i) + } + done.Wait() + + // 验证线程安全 + if dispatcher.Size() != 0 { + t.Error("concurrent delegate management failed") + } +} diff --git a/go.mod b/go.mod index 411b23e..f827976 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module git.kingecg.top/kingecg/goZLMediaKit -go 1.23.1 +go 1.23