aboutsummaryrefslogtreecommitdiff
path: root/internal/auth/service_test.go
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--internal/auth/service_test.go133
1 files changed, 133 insertions, 0 deletions
diff --git a/internal/auth/service_test.go b/internal/auth/service_test.go
new file mode 100644
index 0000000..72ff709
--- /dev/null
+++ b/internal/auth/service_test.go
@@ -0,0 +1,133 @@
+package auth_test
+
+import (
+ "errors"
+ "strings"
+ "testing"
+
+ "golang.org/x/oauth2"
+
+ "github.com/brianvoe/gofakeit"
+ "github.com/golang/mock/gomock"
+
+ "git.ofmax.li/iserv/internal/auth"
+ "git.ofmax.li/iserv/internal/mock/mock_auth"
+)
+
+type serviceSuite struct {
+ as auth.Servicer
+ repo *mock_auth.MockRepo
+}
+
+func TestServices(t *testing.T) {
+ ctrl := gomock.NewController(t)
+ repo := mock_auth.NewMockRepo(ctrl)
+ defer ctrl.Finish()
+ ts := &serviceSuite{
+ auth.NewService(repo),
+ repo,
+ }
+ t.Run("token generation", ts.testGenStateToken())
+ t.Run("token validation", ts.testValidateStateToken())
+ t.Run("auth profile registration", ts.testLoginOrRegsiterSessionId())
+}
+
+func (s *serviceSuite) testGenStateToken() func(t *testing.T) {
+ return func(t *testing.T) {
+ token, _ := s.as.GenerateStateToken()
+ parts := strings.Split(token, ".")
+ if len(parts) != 3 {
+ t.Errorf("token doesn't match format")
+ }
+ }
+}
+
+func (s *serviceSuite) testValidateStateToken() func(t *testing.T) {
+ validToken, _ := s.as.GenerateStateToken()
+ return func(t *testing.T) {
+ cases := []struct {
+ name string
+ token string
+ sessionToken string
+ wanted bool
+ errz error
+ }{
+ {name: "valid matching token",
+ token: validToken,
+ sessionToken: validToken,
+ wanted: true, errz: nil},
+ {name: "non matching token",
+ token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c",
+ sessionToken: "sadfsaf",
+ wanted: false, errz: auth.ErrInvalidToken},
+ {name: "matching but not real",
+ token: "eyJhbGciOTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c",
+ sessionToken: "eyJhbGciOTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c",
+ wanted: false, errz: auth.ErrInvalidJWT},
+ }
+ for _, tc := range cases {
+ isValid, err := s.as.ValidateStateToken(tc.token, tc.sessionToken)
+ if (isValid != tc.wanted) || (err != tc.errz) {
+ t.Fatalf("%s: expected: %v, got: %v err: %v errgot: %v", tc.name, tc.wanted, isValid, tc.errz, err)
+ }
+ }
+ }
+}
+
+func (s *serviceSuite) testLoginOrRegsiterSessionId() func(t *testing.T) {
+ return func(t *testing.T) {
+ // is authorized err
+ gp := &auth.GoogleAuthProfile{}
+ token := &oauth2.Token{}
+ gofakeit.Struct(token)
+ gofakeit.Struct(gp)
+ s.repo.
+ EXPECT().
+ IsAuthorized(gomock.Any()).
+ Return(false, errors.New("foo"))
+ id, isNew, err := s.as.LoginOrRegisterSessionID(token, gp)
+ if err == nil && id == "" && isNew == false {
+ t.Fatalf("%s", id)
+ }
+ // not authorized no error
+ s.repo.
+ EXPECT().
+ IsAuthorized(gomock.Any()).
+ Return(false, nil)
+ id, isNew, err = s.as.LoginOrRegisterSessionID(token, gp)
+ if err != auth.ErrUnauthorized || id != "" || isNew != false {
+ t.Fatalf("unauthorized isnew: %v id: %v err: %v", isNew, id, err.Error())
+ }
+
+ // authorized
+ s.repo.
+ EXPECT().
+ IsAuthorized(gomock.Any()).
+ Return(true, nil)
+ s.repo.
+ EXPECT().
+ LookUpAuthProfileID(gomock.Any()).
+ Return("asdfsafaf", nil)
+ id, isNew, err = s.as.LoginOrRegisterSessionID(token, gp)
+ if err != nil && isNew == false && id == "" {
+ t.Fatalf("auth profile exists isnew: %v id: %v err: %v", isNew, id, err.Error())
+ }
+ // new registration
+ s.repo.
+ EXPECT().
+ IsAuthorized(gomock.Any()).
+ Return(true, nil)
+ s.repo.
+ EXPECT().
+ LookUpAuthProfileID(gomock.Any()).
+ Return("", nil)
+ s.repo.
+ EXPECT().
+ SaveAuthProfile(gomock.Any()).
+ Return(nil)
+ id, isNew, err = s.as.LoginOrRegisterSessionID(token, gp)
+ if err != nil && isNew == false && id == "" {
+ t.Fatalf("isnew: %v id: %v err: %v", isNew, id, err.Error())
+ }
+ }
+}