diff --git a/server/network.go b/server/network.go index 3e5c1f0..7fadec2 100644 --- a/server/network.go +++ b/server/network.go @@ -103,6 +103,7 @@ func (n Network) adaptation(srv *Server) <-chan error { switch n { case NetworkNone: srv.addr = "-" + state <- nil case NetworkTcp: n.gNetMode(state, srv) case NetworkTcp4: @@ -125,6 +126,8 @@ func (n Network) adaptation(srv *Server) <-chan error { n.kcpMode(state, srv) case NetworkGRPC: n.grpcMode(state, srv) + default: + state <- fmt.Errorf("unsupported network mode: %s", n) } return state } @@ -248,7 +251,8 @@ func (n Network) websocketMode(state chan<- error, srv *Server) { if srv.websocketUpgrader == nil { srv.websocketUpgrader = DefaultWebsocketUpgrader() } - http.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) { + mux := http.NewServeMux() + mux.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) { ip := request.Header.Get("X-Real-IP") ws, err := srv.websocketUpgrader.Upgrade(writer, request, nil) if err != nil { @@ -304,17 +308,17 @@ func (n Network) websocketMode(state chan<- error, srv *Server) { srv.PushPacketMessage(conn, messageType, packet) } }) - go func(lis *listener) { + go func(lis *listener, mux *http.ServeMux) { var err error if len(lis.srv.certFile)+len(lis.srv.keyFile) > 0 { - err = http.ServeTLS(lis, nil, lis.srv.certFile, lis.srv.keyFile) + err = http.ServeTLS(lis, mux, lis.srv.certFile, lis.srv.keyFile) } else { - err = http.Serve(lis, nil) + err = http.Serve(lis, mux) } if err != nil { super.TryWriteChannel(lis.state, err) } - }((&listener{srv: srv, Listener: l, state: state}).init()) + }((&listener{srv: srv, Listener: l, state: state}).init(), mux) } // IsSocket 返回当前服务器的网络模式是否为 Socket 模式