diff --git a/server/conn.go b/server/conn.go index e228d84..d85200f 100644 --- a/server/conn.go +++ b/server/conn.go @@ -137,6 +137,12 @@ func (slf *Conn) WriteString(data string, messageType ...int) { slf.Write([]byte(data), messageType...) } +// WriteStringWithCallback 与 WriteString 相同,但是会在写入完成后调用 callback +// - 当 callback 为 nil 时,与 WriteString 相同 +func (slf *Conn) WriteStringWithCallback(data string, callback func(err error), messageType ...int) { + slf.WriteWithCallback([]byte(data), callback, messageType...) +} + // Write 向连接中写入数据 // - messageType: websocket模式中指定消息类型 func (slf *Conn) Write(data []byte, messageType ...int) { @@ -153,6 +159,23 @@ func (slf *Conn) Write(data []byte, messageType ...int) { slf.mutex.Unlock() } +// WriteWithCallback 与 Write 相同,但是会在写入完成后调用 callback +// - 当 callback 为 nil 时,与 Write 相同 +func (slf *Conn) WriteWithCallback(data []byte, callback func(err error), messageType ...int) { + if slf.packetPool == nil { + return + } + cp := slf.packetPool.Get() + if len(messageType) > 0 { + cp.websocketMessageType = messageType[0] + } + cp.packet = data + cp.callback = callback + slf.mutex.Lock() + slf.packets = append(slf.packets, cp) + slf.mutex.Unlock() +} + // writeLoop 写循环 func (slf *Conn) writeLoop(wait *sync.WaitGroup) { slf.packetPool = synchronization.NewPool[*connPacket](10*1024, @@ -161,6 +184,7 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) { }, func(data *connPacket) { data.packet = nil data.websocketMessageType = 0 + data.callback = nil }, ) defer func() { @@ -210,9 +234,14 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) { _, err = slf.kcp.Write(data.packet) } } + callback := data.callback slf.packetPool.Release(data) if err != nil { - panic(err) + if callback != nil { + callback(err) + } else { + panic(err) + } } } } diff --git a/server/conn_packet.go b/server/conn_packet.go index 6ee6684..0244517 100644 --- a/server/conn_packet.go +++ b/server/conn_packet.go @@ -3,4 +3,5 @@ package server type connPacket struct { websocketMessageType int packet []byte + callback func(err error) }