diff --git a/internal/base/middleware/recovery.go b/internal/base/middleware/recovery.go new file mode 100644 index 000000000..7e844f78e --- /dev/null +++ b/internal/base/middleware/recovery.go @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package middleware + +import ( + "net/http" + "runtime/debug" + + "github.com/apache/answer/internal/base/handler" + "github.com/apache/answer/internal/base/reason" + "github.com/gin-gonic/gin" + "github.com/segmentfault/pacman/log" +) + +func Recovery() gin.HandlerFunc { + return func(ctx *gin.Context) { + defer func() { + if err := recover(); err != nil { + log.Errorf("panic recovered: %v\n%s", err, debug.Stack()) + ctx.AbortWithStatusJSON(http.StatusInternalServerError, + handler.NewRespBody(http.StatusInternalServerError, reason.UnknownError).TrMsg(handler.GetLangByCtx(ctx)), + ) + } + }() + ctx.Next() + } +} diff --git a/internal/base/middleware/recovery_test.go b/internal/base/middleware/recovery_test.go new file mode 100644 index 000000000..c01719fce --- /dev/null +++ b/internal/base/middleware/recovery_test.go @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package middleware + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestRecovery_Panic(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(Recovery()) + r.GET("/panic", func(ctx *gin.Context) { + panic("test panic") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/panic", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d", w.Code) + } + + var body map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { + t.Fatalf("response is not valid JSON: %v", err) + } + if body["reason"] != "base.unknown" { + t.Errorf("unexpected reason: %v", body["reason"]) + } +} + +func TestRecovery_NoPanic(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(Recovery()) + r.GET("/ok", func(ctx *gin.Context) { + ctx.String(http.StatusOK, "ok") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/ok", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } +} diff --git a/internal/base/server/http.go b/internal/base/server/http.go index 765cbf6be..1e8204d36 100644 --- a/internal/base/server/http.go +++ b/internal/base/server/http.go @@ -52,6 +52,7 @@ func NewHTTPServer(debug bool, gin.SetMode(gin.ReleaseMode) } r := gin.New() + r.Use(middleware.Recovery()) r.Use(func(ctx *gin.Context) { if strings.Contains(ctx.Request.URL.Path, "/chat/completions") { return