package authz

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/http/httptest"
	"testing"

	"git.ofmax.li/go-git-server/internal/admin"
)

// junkTestHandler returns a handler that answers every request with
// 200 OK and the body "Im a body". The middleware tests use it as the
// wrapped ("next") handler.
func junkTestHandler() http.HandlerFunc {
	return func(rw http.ResponseWriter, req *http.Request) {
		rw.WriteHeader(http.StatusOK)
		if _, err := rw.Write([]byte("Im a body")); err != nil {
			log.Fatalf("couldn't write http body %s", err)
		}
	}
}

// TestAuthentication drives the Authentication middleware with basic-auth
// credentials and checks the resulting status code. For the successful
// login it additionally checks that the wrapped handler ran and found the
// expected value under AuthzUrnKey in the request context.
func TestAuthentication(t *testing.T) {
	badToken, _, err := GenerateNewToken()
	if err != nil {
		t.Fatalf("couldn't generate the mismatching token: %v", err)
	}
	token, hash, err := GenerateNewToken()
	if err != nil {
		t.Fatalf("couldn't generate the valid token: %v", err)
	}

	okUserName := "tester"
	badUserName := "badb00"
	tm := TokenMap{}
	tm["uid:tester"] = hash

	cases := []struct {
		description string
		username    string
		token       string
		tm          TokenMap
		statusCode  int
		// wantURN, when non-empty, is the value the wrapped handler must
		// observe under AuthzUrnKey; it also requires the handler to run.
		wantURN string
	}{
		{
			description: "Good Login",
			username:    okUserName,
			token:       token,
			tm:          tm,
			statusCode:  http.StatusOK,
			wantURN:     fmt.Sprintf("uid:%s", okUserName),
		},
		{
			description: "Bad usename",
			username:    badUserName,
			token:       token,
			tm:          tm,
			statusCode:  http.StatusForbidden,
		},
		{
			description: "Bad token",
			username:    okUserName,
			token:       badToken,
			tm:          tm,
			statusCode:  http.StatusForbidden,
		},
	}

	for _, tc := range cases {
		tc := tc // per-iteration copy for toolchains older than Go 1.22
		t.Run(tc.description, func(t *testing.T) {
			// Record what the wrapped handler saw instead of failing from
			// inside it, so every assertion happens on the test goroutine.
			var (
				called bool
				gotURN interface{}
			)
			next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
				called = true
				gotURN = req.Context().Value(AuthzUrnKey)
				junkTestHandler()(rw, req)
			})

			authHandler := Authentication(tc.tm, next)
			req := httptest.NewRequest(http.MethodGet, "https://git.ofmax.li", nil)
			req.SetBasicAuth(tc.username, tc.token)
			recorder := httptest.NewRecorder()
			authHandler.ServeHTTP(recorder, req)

			result := recorder.Result()
			// Inside the subtest closure the defer fires per case rather
			// than piling up until the whole test function returns.
			defer result.Body.Close()

			if result.StatusCode != tc.statusCode {
				t.Fatalf("Test Case %s failed Expected: %d Found: %d",
					tc.description, tc.statusCode, result.StatusCode)
			}
			if tc.wantURN != "" {
				if !called {
					t.Fatal("wrapped handler was not invoked")
				}
				if gotURN != tc.wantURN {
					t.Fatal("Context UID not set")
				}
			}
			t.Logf("Test Case: %s Expected: %d Found: %d",
				tc.description, tc.statusCode, result.StatusCode)
		})
	}
}

// TestAuthorization drives the Authorization middleware with a request whose
// context already carries an AuthzUrnKey value and checks both the status
// code and the response body. The service is built from the model, policy,
// and config fixtures referenced by relative path below.
func TestAuthorization(t *testing.T) {
	t.Log("Starting authorization tests")
	baseURL := "http://test"

	cases := []struct {
		url            string
		user           string
		expectedStatus int
		description    string
		body           []byte
	}{
		{
			url:            fmt.Sprintf("%s/%s", baseURL, "repo/url"),
			user:           "uid:jack",
			expectedStatus: http.StatusOK,
			description:    "an authorized action should yield a 200",
			body:           []byte("Im a body"),
		},
		{
			url:            fmt.Sprintf("%s/%s", baseURL, "repo/url/bar"),
			user:           "uid:chumba",
			expectedStatus: http.StatusForbidden,
			description:    "an unauthorized action should yield a 403",
			body:           []byte("Access denied\n"),
		},
		{
			url:            fmt.Sprintf("%s/%s", baseURL, "repo/url/bar"),
			user:           "anon",
			expectedStatus: http.StatusUnauthorized,
			description:    "an unauthenticated action should yield a 401",
			body:           []byte("Authentication Required\n"),
		},
	}

	svcr, err := admin.NewService(
		"../../auth_model.ini",
		"../../tests/testpolicy.csv",
		"../../gitserver.yaml",
		"../../repos",
		false)
	if err != nil {
		// Fail here rather than letting a broken fixture path surface
		// later as an unrelated status-code mismatch or nil dereference.
		t.Fatalf("couldn't create admin service: %v", err)
	}

	for _, tc := range cases {
		tc := tc // per-iteration copy for toolchains older than Go 1.22
		t.Run(tc.description, func(t *testing.T) {
			t.Logf("test case: %s", tc.description)
			authHandler := Authorization(svcr, junkTestHandler())
			recorder := httptest.NewRecorder()

			req := httptest.NewRequest(http.MethodGet, tc.url, nil)
			ctx := context.WithValue(req.Context(), AuthzUrnKey, tc.user)
			req = req.WithContext(ctx)
			authHandler.ServeHTTP(recorder, req)

			result := recorder.Result()
			defer result.Body.Close()

			body, err := io.ReadAll(result.Body)
			if err != nil {
				t.Fatal("couldn't read response body")
			}
			if result.StatusCode != tc.expectedStatus {
				t.Fatalf("Test Case %s failed Expected: %d Found: %d",
					tc.description, tc.expectedStatus, result.StatusCode)
			}
			if !bytes.Equal(body, tc.body) {
				// %q prints the bodies as readable quoted text.
				t.Fatalf("Test Case %s failed Expected: %q Found: %q",
					tc.description, tc.body, body)
			}
		})
	}
}