From 5157f78430764b582418d5af292f59af617e7206 Mon Sep 17 00:00:00 2001 From: Radhakrishna Sanka Date: Sun, 3 Dec 2023 21:10:01 +0530 Subject: [PATCH] Updated the basic documentation example Improving code clarity I'm starting to make some changes to the code documentation here. I'm including a bit of refactoring to make sure that things are more readable. - The code extensively uses the short variable names to the detriment of the readability. This compounds the problem of not having enough code documentation. - While golang recommends short variable names, the intention is to avoid long nested function calls and to avoid diverting the attention from the main logic of the program. However, even the main logic variables have really short names making it almost impossible to read. Incremental renaming of variables in intro.go First pass of the clarification renaming: 1. Changed all type structs were Pascal Case 2. Expanded 2-3 character field names for structs 3. Tried removing redundant functions Add comments and update function names Add new comments and function names for better code clarity Updated go go project files Undid struct naming that pushed stuff to be public Undid the public struct scoping --- auth/auth.go | 6 + auth/internal/rails/auth.go | 7 +- auth/internal/rails/cookie.go | 2 + auth/jwt.go | 2 + auth/provider/auth0.go | 5 + auth/provider/firebase.go | 6 + auth/provider/generic.go | 5 + auth/provider/jwks.go | 7 ++ auth/provider/provider.go | 10 +- auth/rails.go | 5 + cmd/cmd.go | 7 +- cmd/cmd_admin.go | 5 + cmd/cmd_db.go | 2 + cmd/cmd_migrate.go | 11 +- cmd/cmd_new.go | 2 + cmd/cmd_secrets.go | 2 +- cmd/cmd_seed.go | 8 +- cmd/cmd_serv.go | 2 + cmd/cmd_version.go | 3 + conf/config.go | 3 + core/api.go | 201 ++++++++++++++++--------------- core/args.go | 4 +- core/cache.go | 5 +- core/config.go | 2 +- core/core.go | 43 +++---- core/crypt.go | 10 ++ core/gstate.go | 30 ++--- core/init.go | 75 +++++++----- core/internal/allow/allow.go | 31 +++-- core/internal/allow/gql.go | 2 + core/internal/assert/assert.go | 4 + core/internal/graph/lex.go | 22 ++-- core/internal/graph/utils.go | 3 + core/internal/sdata/dwg.go | 55 ++++++--- core/internal/sdata/schema.go | 96 +++++++++------ core/internal/sdata/strings.go | 4 + core/internal/sdata/tables.go | 17 +++ core/internal/util/graph.go | 8 ++ core/internal/util/graph_test.go | 2 +- core/intro.go | 84 +++++++++---- core/osfs.go | 15 ++- core/remote_api.go | 4 +- core/remote_join.go | 6 +- core/resolve.go | 9 +- core/rolestmt.go | 8 +- core/schema.go | 3 + core/subs.go | 65 ++++++---- core/trace.go | 6 + core/watcher.go | 6 +- go.work.sum | 1 + serv/admin.go | 16 ++- serv/afero.go | 4 + serv/api.go | 78 ++++++------ serv/client.go | 4 + serv/config.go | 99 ++++++++------- serv/db.go | 65 +++++----- serv/deploy.go | 19 ++- serv/filewatch.go | 7 +- serv/health.go | 5 +- serv/http.go | 35 ++++-- serv/init.go | 18 ++- serv/internal/secrets/decrypt.go | 3 + serv/internal/secrets/edit.go | 7 ++ serv/internal/secrets/init.go | 5 +- serv/internal/secrets/rotate.go | 1 + serv/internal/secrets/run.go | 2 + serv/internal/util/log.go | 2 + serv/internal/util/viper.go | 1 + serv/iplimiter.go | 7 +- serv/migrate.go | 2 + serv/routes.go | 5 +- serv/secrets.go | 3 + serv/serv.go | 16 ++- serv/telemetry.go | 7 +- serv/webui.go | 1 + serv/ws.go | 11 +- tests/core_test.go | 6 +- 77 files changed, 869 insertions(+), 481 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 7a255327..a4a7a0cf 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -172,6 +172,7 @@ func NewAuthHandlerFunc(ac Auth) (HandlerFunc, error) { return h, err } +// NoAuth returns a handler that does not perform any authentication. func NoAuth() (HandlerFunc, error) { return func(w http.ResponseWriter, r *http.Request) (context.Context, error) { return r.Context(), nil @@ -232,6 +233,7 @@ func NewAuth(ac Auth, log *zap.Logger, opt Options, hFn ...HandlerFunc) ( }, nil } +// SimpleHandler is a simple auth handler that sets the user ID, provider and role func SimpleHandler(ac Auth) (HandlerFunc, error) { return func(_ http.ResponseWriter, r *http.Request) (context.Context, error) { c := r.Context() @@ -257,6 +259,7 @@ func SimpleHandler(ac Auth) (HandlerFunc, error) { var Err401 = errors.New("401 unauthorized") +// HeaderHandler is a middleware that checks for a header value func HeaderHandler(ac Auth) (HandlerFunc, error) { hdr := ac.Header @@ -287,14 +290,17 @@ func HeaderHandler(ac Auth) (HandlerFunc, error) { }, nil } +// IsAuth returns true if the context contains a user ID func IsAuth(c context.Context) bool { return c != nil && c.Value(core.UserIDKey) != nil } +// UserID returns the user ID from the context func UserID(c context.Context) interface{} { return c.Value(core.UserIDKey) } +// UserIDInt returns the user ID from the context as an int func UserIDInt(c context.Context) int { v, ok := UserID(c).(string) if !ok { diff --git a/auth/internal/rails/auth.go b/auth/internal/rails/auth.go index cc0dc325..a84206f5 100644 --- a/auth/internal/rails/auth.go +++ b/auth/internal/rails/auth.go @@ -30,6 +30,7 @@ type Auth struct { AuthSalt string } +// NewAuth creates a new Auth instance func NewAuth(version, secret string) (*Auth, error) { ra := &Auth{ Secret: secret, @@ -60,6 +61,7 @@ func NewAuth(version, secret string) (*Auth, error) { return ra, nil } +// ParseCookie parses the rails cookie and returns the user ID func (ra Auth) ParseCookie(cookie string) (userID string, err error) { var dcookie []byte @@ -87,6 +89,7 @@ func (ra Auth) ParseCookie(cookie string) (userID string, err error) { return } +// ParseCookie parses the rails cookie and returns the user ID func ParseCookie(cookie string) (string, error) { if cookie[0] != '{' { return getUserId4([]byte(cookie)) @@ -95,6 +98,7 @@ func ParseCookie(cookie string) (string, error) { return getUserId([]byte(cookie)) } +// getUserId extracts the user ID from the session data func getUserId(data []byte) (userID string, err error) { var sessionData map[string]interface{} @@ -135,10 +139,11 @@ func getUserId(data []byte) (userID string, err error) { return } +// getUserId4 extracts the user ID from the session data func getUserId4(data []byte) (userID string, err error) { sessionData, err := marshal.CreateMarshalledObject(data).GetAsMap() if err != nil { - return + return "", err } wardenData, ok := sessionData["warden.user.user.key"] diff --git a/auth/internal/rails/cookie.go b/auth/internal/rails/cookie.go index 72c225cb..8bc7fc8f 100644 --- a/auth/internal/rails/cookie.go +++ b/auth/internal/rails/cookie.go @@ -12,6 +12,7 @@ import ( "golang.org/x/crypto/pbkdf2" ) +// parseCookie decrypts and parses a Rails session cookie func parseCookie(cookie, secretKeyBase, salt, signSalt string) ([]byte, error) { return session.DecryptSignedCookie( cookie, @@ -22,6 +23,7 @@ func parseCookie(cookie, secretKeyBase, salt, signSalt string) ([]byte, error) { // {"session_id":"a71d6ffcd4ed5572ea2097f569eb95ef","warden.user.user.key":[[2],"$2a$11$q9Br7m4wJxQvF11hAHvTZO"],"_csrf_token":"HsYgrD2YBaWAabOYceN0hluNRnGuz49XiplmMPt43aY="} +// parseCookie52 decrypts and parses a Rails 5.2+ session cookie func parseCookie52(cookie, secretKeyBase, authSalt string) ([]byte, error) { ecookie, err := url.QueryUnescape(cookie) if err != nil { diff --git a/auth/jwt.go b/auth/jwt.go index 63e9c9a4..c3743103 100644 --- a/auth/jwt.go +++ b/auth/jwt.go @@ -14,6 +14,8 @@ const ( authHeader = "Authorization" ) +// JwtHandler is a middleware that checks for a JWT token in the cookie or the +// authorization header. If the token is found, it is validated and the claims func JwtHandler(ac Auth) (HandlerFunc, error) { jwtProvider, err := provider.NewProvider(ac.JWT) if err != nil { diff --git a/auth/provider/auth0.go b/auth/provider/auth0.go index 1a3e07b8..9125a53d 100644 --- a/auth/provider/auth0.go +++ b/auth/provider/auth0.go @@ -15,6 +15,7 @@ type Auth0Provider struct { issuer string } +// NewAuth0Provider creates a new Auth0 JWT provider func NewAuth0Provider(config JWTConfig) (*Auth0Provider, error) { key, err := getKey(config) if err != nil { @@ -27,12 +28,14 @@ func NewAuth0Provider(config JWTConfig) (*Auth0Provider, error) { }, nil } +// KeyFunc returns a function that returns the key used to verify the JWT token func (p *Auth0Provider) KeyFunc() jwt.Keyfunc { return func(token *jwt.Token) (interface{}, error) { return p.key, nil } } +// VerifyAudience checks if the audience claim is valid func (p *Auth0Provider) VerifyAudience(claims jwt.MapClaims) bool { if claims == nil { return false @@ -40,6 +43,7 @@ func (p *Auth0Provider) VerifyAudience(claims jwt.MapClaims) bool { return claims.VerifyAudience(p.aud, p.aud != "") } +// VerifyIssuer checks if the issuer claim is valid func (p *Auth0Provider) VerifyIssuer(claims jwt.MapClaims) bool { if claims == nil { return false @@ -47,6 +51,7 @@ func (p *Auth0Provider) VerifyIssuer(claims jwt.MapClaims) bool { return claims.VerifyIssuer(p.issuer, p.issuer != "") } +// SetContextValues sets the user ID and provider in the context func (p *Auth0Provider) SetContextValues(ctx context.Context, claims jwt.MapClaims) (context.Context, error) { if claims == nil { return ctx, errors.New("undefined claims") diff --git a/auth/provider/firebase.go b/auth/provider/firebase.go index fceab5a4..c646e830 100644 --- a/auth/provider/firebase.go +++ b/auth/provider/firebase.go @@ -35,6 +35,7 @@ type FirebaseProvider struct { issuer string } +// NewFirebaseProvider creates a new Firebase JWT provider func NewFirebaseProvider(config JWTConfig) (*FirebaseProvider, error) { issuer := config.Issuer if issuer == "" { @@ -46,10 +47,12 @@ func NewFirebaseProvider(config JWTConfig) (*FirebaseProvider, error) { }, nil } +// KeyFunc returns a function that returns the key used to verify the JWT token func (p *FirebaseProvider) KeyFunc() jwt.Keyfunc { return firebaseKeyFunction } +// VerifyAudience checks if the audience claim is valid func (p *FirebaseProvider) VerifyAudience(claims jwt.MapClaims) bool { if claims == nil { return false @@ -57,6 +60,7 @@ func (p *FirebaseProvider) VerifyAudience(claims jwt.MapClaims) bool { return claims.VerifyAudience(p.aud, p.aud != "") } +// VerifyIssuer checks if the issuer claim is valid func (p *FirebaseProvider) VerifyIssuer(claims jwt.MapClaims) bool { if claims == nil { return false @@ -64,6 +68,7 @@ func (p *FirebaseProvider) VerifyIssuer(claims jwt.MapClaims) bool { return claims.VerifyIssuer(p.issuer, p.issuer != "") } +// SetContextValues sets the user ID and provider in the context func (p *FirebaseProvider) SetContextValues(ctx context.Context, claims jwt.MapClaims) (context.Context, error) { if claims == nil { return ctx, errors.New("undefined claims") @@ -85,6 +90,7 @@ func (e *firebaseKeyError) Error() string { return e.Message + " " + e.Err.Error() } +// firebaseKeyFunction returns the public key used to verify the JWT token func firebaseKeyFunction(token *jwt.Token) (interface{}, error) { kid, ok := token.Header["kid"] diff --git a/auth/provider/generic.go b/auth/provider/generic.go index 6d9f0306..0a868154 100644 --- a/auth/provider/generic.go +++ b/auth/provider/generic.go @@ -14,6 +14,7 @@ type GenericProvider struct { issuer string } +// NewGenericProvider creates a new generic JWT provider func NewGenericProvider(config JWTConfig) (*GenericProvider, error) { key, err := getKey(config) if err != nil { @@ -26,12 +27,14 @@ func NewGenericProvider(config JWTConfig) (*GenericProvider, error) { }, nil } +// KeyFunc returns a function that returns the key used to verify the JWT token func (p *GenericProvider) KeyFunc() jwt.Keyfunc { return func(token *jwt.Token) (interface{}, error) { return p.key, nil } } +// VerifyAudience verifies the audience claim of the JWT token func (p *GenericProvider) VerifyAudience(claims jwt.MapClaims) bool { if claims == nil { return false @@ -39,6 +42,7 @@ func (p *GenericProvider) VerifyAudience(claims jwt.MapClaims) bool { return claims.VerifyAudience(p.aud, p.aud != "") } +// VerifyIssuer verifies the issuer claim of the JWT token func (p *GenericProvider) VerifyIssuer(claims jwt.MapClaims) bool { if claims == nil { return false @@ -46,6 +50,7 @@ func (p *GenericProvider) VerifyIssuer(claims jwt.MapClaims) bool { return claims.VerifyIssuer(p.issuer, p.issuer != "") } +// SetContextValues sets the user ID and provider in the context func (p *GenericProvider) SetContextValues(ctx context.Context, claims jwt.MapClaims) (context.Context, error) { if claims == nil { return ctx, errors.New("undefined claims") diff --git a/auth/provider/jwks.go b/auth/provider/jwks.go index 9a12a77a..25f3fc55 100644 --- a/auth/provider/jwks.go +++ b/auth/provider/jwks.go @@ -19,6 +19,7 @@ type keychainCache struct { semaphore int32 } +// newKeychainCache creates a new KeychainCache func newKeychainCache(jwksURL string, refreshInterval, minRefreshInterval int) *keychainCache { ar := jwk.NewAutoRefresh(context.Background()) if refreshInterval > 0 { @@ -34,6 +35,7 @@ func newKeychainCache(jwksURL string, refreshInterval, minRefreshInterval int) * } } +// getKey returns the key from the cache func (k *keychainCache) getKey(kid string) (interface{}, error) { set, err := k.keyCache.Fetch(context.TODO(), k.jwksURL) if err != nil { @@ -89,6 +91,7 @@ type JWKSProvider struct { cache *keychainCache } +// NewJWKSProvider creates a new JWKSProvider func NewJWKSProvider(config JWTConfig) (*JWKSProvider, error) { if config.JWKSURL == "" { return nil, errors.New("undefined JWKSURL") @@ -100,6 +103,7 @@ func NewJWKSProvider(config JWTConfig) (*JWKSProvider, error) { }, nil } +// KeyFunc returns a function that returns the key used to verify the JWT token func (p *JWKSProvider) KeyFunc() jwt.Keyfunc { return func(token *jwt.Token) (interface{}, error) { if token == nil { @@ -123,6 +127,7 @@ func (p *JWKSProvider) KeyFunc() jwt.Keyfunc { } } +// VerifyAudience checks if the audience claim is valid func (p *JWKSProvider) VerifyAudience(claims jwt.MapClaims) bool { if claims == nil { return false @@ -130,6 +135,7 @@ func (p *JWKSProvider) VerifyAudience(claims jwt.MapClaims) bool { return claims.VerifyAudience(p.aud, p.aud != "") } +// VerifyIssuer checks if the issuer claim is valid func (p *JWKSProvider) VerifyIssuer(claims jwt.MapClaims) bool { if claims == nil { return false @@ -137,6 +143,7 @@ func (p *JWKSProvider) VerifyIssuer(claims jwt.MapClaims) bool { return claims.VerifyIssuer(p.issuer, p.issuer != "") } +// SetContextValues sets the user ID and provider in the context func (p *JWKSProvider) SetContextValues(ctx context.Context, claims jwt.MapClaims) (context.Context, error) { if claims == nil { return ctx, errors.New("undefined claims") diff --git a/auth/provider/provider.go b/auth/provider/provider.go index d322b9ee..efafa2f7 100644 --- a/auth/provider/provider.go +++ b/auth/provider/provider.go @@ -51,6 +51,7 @@ type JWTProvider interface { SetContextValues(context.Context, jwt.MapClaims) (context.Context, error) } +// NewProvider creates a new JWT provider based on the config values func NewProvider(config JWTConfig) (JWTProvider, error) { switch config.Provider { case "auth0": @@ -64,20 +65,21 @@ func NewProvider(config JWTConfig) (JWTProvider, error) { } } +// getKey returns the key used to verify the JWT token func getKey(config JWTConfig) (interface{}, error) { var key interface{} var err error switch { case config.PubKey != "": - pk := []byte(config.PubKey) + pubKey := []byte(config.PubKey) switch config.PubKeyType { case "ecdsa": - key, err = jwt.ParseECPublicKeyFromPEM(pk) + key, err = jwt.ParseECPublicKeyFromPEM(pubKey) case "rsa": - key, err = jwt.ParseRSAPublicKeyFromPEM(pk) + key, err = jwt.ParseRSAPublicKeyFromPEM(pubKey) default: - key, err = jwt.ParseECPublicKeyFromPEM(pk) + key, err = jwt.ParseECPublicKeyFromPEM(pubKey) } if err != nil { return nil, err diff --git a/auth/rails.go b/auth/rails.go index 6f4611a4..0eddd4bd 100644 --- a/auth/rails.go +++ b/auth/rails.go @@ -15,6 +15,7 @@ import ( "github.com/gomodule/redigo/redis" ) +// RailsHandler returns a handler that authenticates using a Rails session cookie func RailsHandler(ac Auth) (HandlerFunc, error) { ru := ac.Rails.URL @@ -29,6 +30,7 @@ func RailsHandler(ac Auth) (HandlerFunc, error) { return RailsCookieHandler(ac) } +// RailsRedisHandler returns a handler that authenticates using a Rails session cookie func RailsRedisHandler(ac Auth) (HandlerFunc, error) { cookie := ac.Cookie @@ -95,6 +97,7 @@ func RailsRedisHandler(ac Auth) (HandlerFunc, error) { }, nil } +// RailsMemcacheHandler returns a handler that authenticates using a Rails session cookie func RailsMemcacheHandler(ac Auth) (HandlerFunc, error) { cookie := ac.Cookie @@ -138,6 +141,7 @@ func RailsMemcacheHandler(ac Auth) (HandlerFunc, error) { }, nil } +// RailsCookieHandler returns a handler that authenticates using a Rails session cookie func RailsCookieHandler(ac Auth) (HandlerFunc, error) { cookie := ac.Cookie if len(cookie) == 0 { @@ -168,6 +172,7 @@ func RailsCookieHandler(ac Auth) (HandlerFunc, error) { }, nil } +// railsAuth returns a new rails auth instance func railsAuth(ac Auth) (*rails.Auth, error) { secret := ac.Rails.SecretKeyBase if len(secret) == 0 { diff --git a/cmd/cmd.go b/cmd/cmd.go index 96ce0944..cb96f887 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -29,6 +29,7 @@ var ( cpath string ) +// Cmd is the entry point for the CLI func Cmd() { log = newLogger(false).Sugar() @@ -64,14 +65,12 @@ func Cmd() { } } +// setup is a helper function to read the config file func setup(cpath string) { if conf != nil { return } - setupAgain(cpath) -} -func setupAgain(cpath string) { cp, err := filepath.Abs(cpath) if err != nil { log.Fatal(err) @@ -83,6 +82,7 @@ func setupAgain(cpath string) { } } +// initDB is a helper function to initialize the database connection func initDB(openDB bool) { var err error @@ -97,6 +97,7 @@ func initDB(openDB bool) { dbOpened = openDB } +// newLogger creates a new logger func newLogger(json bool) *zap.Logger { econf := zapcore.EncoderConfig{ MessageKey: "msg", diff --git a/cmd/cmd_admin.go b/cmd/cmd_admin.go index c601d196..c0646630 100644 --- a/cmd/cmd_admin.go +++ b/cmd/cmd_admin.go @@ -16,6 +16,7 @@ var ( secret string ) +// deployCmd deploys a new config or rolls back the active config func deployCmd() *cobra.Command { c := &cobra.Command{ Use: "deploy", @@ -36,6 +37,7 @@ func deployCmd() *cobra.Command { return c } +// initCmd initializes the admin database func initCmd() *cobra.Command { c := &cobra.Command{ Use: "init", @@ -45,6 +47,7 @@ func initCmd() *cobra.Command { return c } +// cmdInit initializes the admin database func cmdInit(cmd *cobra.Command, args []string) { setup(cpath) initDB(true) @@ -56,6 +59,7 @@ func cmdInit(cmd *cobra.Command, args []string) { log.Infof("init successful: %s", name) } +// cmdDeploy deploys a new config func cmdDeploy(cmd *cobra.Command, args []string) { if host == "" { log.Fatalf("--host is a required argument") @@ -79,6 +83,7 @@ func cmdDeploy(cmd *cobra.Command, args []string) { } } +// cmdRollback rolls back the active config func cmdRollback(cmd *cobra.Command, args []string) { if host == "" { log.Fatalf("--host is a required argument") diff --git a/cmd/cmd_db.go b/cmd/cmd_db.go index f1d5e896..cba08c40 100644 --- a/cmd/cmd_db.go +++ b/cmd/cmd_db.go @@ -4,6 +4,7 @@ import ( "github.com/spf13/cobra" ) +// dbCmd creates the db command func dbCmd() *cobra.Command { c := &cobra.Command{ Use: "db", @@ -30,6 +31,7 @@ func dbCmd() *cobra.Command { return c } +// cmdDBSeed seeds the database func cmdDBSetup(cmd *cobra.Command, args []string) { setup(cpath) diff --git a/cmd/cmd_migrate.go b/cmd/cmd_migrate.go index fa211142..4cc5e18c 100644 --- a/cmd/cmd_migrate.go +++ b/cmd/cmd_migrate.go @@ -14,6 +14,7 @@ import ( "golang.org/x/text/language" ) +// This is the cobra CLI command for the migrate subcommand func migrateCmd() *cobra.Command { c := &cobra.Command{ Use: "migrate", @@ -69,6 +70,7 @@ var newMigrationText = `-- Write your migrate up statements here -- Then delete the separator line above. ` +// cmdDBMigrate is the main function for the migrate subcommand func cmdDBMigrate(cmd *cobra.Command, args []string) { doneSomething := false @@ -93,7 +95,7 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) { m.Data = getMigrationVars(conf) - err = m.LoadMigrations(conf.RelPath(conf.MigrationsPath)) + err = m.LoadMigrations(conf.AbsolutePath(conf.MigrationsPath)) if err != nil { log.Fatalf("Failed to load migrations: %s", err) } @@ -197,6 +199,7 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) { } } +// cmdMigrateStatus is the function for the migrate status subcommand func cmdMigrateStatus(cmd *cobra.Command, args []string) { setup(cpath) initDB(true) @@ -212,7 +215,7 @@ func cmdMigrateStatus(cmd *cobra.Command, args []string) { m.Data = getMigrationVars(conf) - err = m.LoadMigrations(conf.RelPath(conf.MigrationsPath)) + err = m.LoadMigrations(conf.AbsolutePath(conf.MigrationsPath)) if err != nil { log.Fatalf("Failed to load migrations: %s", err) } @@ -238,6 +241,7 @@ func cmdMigrateStatus(cmd *cobra.Command, args []string) { status, mver, len(m.Migrations), conf.DB.Host, conf.DB.DBName) } +// cmdMigrateNew is the function for the migrate new subcommand func cmdMigrateNew(cmd *cobra.Command, args []string) { if len(args) != 1 { cmd.Help() //nolint:errcheck @@ -248,7 +252,7 @@ func cmdMigrateNew(cmd *cobra.Command, args []string) { initDB(false) name := args[0] - migrationsPath := conf.RelPath(conf.MigrationsPath) + migrationsPath := conf.AbsolutePath(conf.MigrationsPath) m, err := migrate.FindMigrations(migrationsPath) if err != nil { @@ -306,6 +310,7 @@ func ExtractErrorLine(source string, position int) (ErrorLineExtract, error) { return ele, nil } +// getMigrationVars returns the variables to be used in the migration templates func getMigrationVars(c *serv.Config) map[string]interface{} { en := cases.Title(language.English) diff --git a/cmd/cmd_new.go b/cmd/cmd_new.go index b7d52a41..12df832a 100644 --- a/cmd/cmd_new.go +++ b/cmd/cmd_new.go @@ -17,6 +17,7 @@ import ( var dbURL string +// This is the cobra CLI command for the new subcommand func newCmd() *cobra.Command { c := &cobra.Command{ Use: "new ", @@ -29,6 +30,7 @@ func newCmd() *cobra.Command { return c } +// cmdNew is the handler for the new subcommand func cmdNew(cmd *cobra.Command, args []string) { if len(args) != 1 { cmd.Help() //nolint:errcheck diff --git a/cmd/cmd_secrets.go b/cmd/cmd_secrets.go index da78b3b7..4116f7fd 100644 --- a/cmd/cmd_secrets.go +++ b/cmd/cmd_secrets.go @@ -82,7 +82,7 @@ func cmdSecrets() *cobra.Command { } else { setup(cpath) if conf.SecretsFile != "" { - fileName, err = filepath.Abs(conf.RelPath(conf.SecretsFile)) + fileName, err = filepath.Abs(conf.AbsolutePath(conf.SecretsFile)) } } diff --git a/cmd/cmd_seed.go b/cmd/cmd_seed.go index 2d427fb5..2c099326 100644 --- a/cmd/cmd_seed.go +++ b/cmd/cmd_seed.go @@ -24,6 +24,7 @@ import ( "github.com/spf13/cobra" ) +// cmdSeed is the cobra CLI for the seed subcommand func cmdDBSeed(cmd *cobra.Command, args []string) { setup(cpath) initDB(true) @@ -49,6 +50,7 @@ func cmdDBSeed(cmd *cobra.Command, args []string) { log.Infof("Seed script completed") } +// compileAndRunJS compiles and runs the seed script func compileAndRunJS(seed string, db *sql.DB) error { b, err := os.ReadFile(seed) if err != nil { @@ -168,7 +170,7 @@ func compileAndRunJS(seed string, db *sql.DB) error { return err } -// func runFunc(call goja.FunctionCall) { +// graphQLFunc is a helper function to run a GraphQL query func graphQLFunc(gj *core.GraphJin, query string, data interface{}, opt map[string]string) map[string]interface{} { ct := context.Background() @@ -214,6 +216,7 @@ type csvSource struct { i int } +// NewCSVSource creates a new CSV source func NewCSVSource(filename string, sep rune) (*csvSource, error) { f, err := os.Open(filename) if err != nil { @@ -272,6 +275,7 @@ func (c *csvSource) Values() ([]interface{}, error) { return vals, nil } +// isDigit checks if a string is a digit func isDigit(v string) bool { for i := range v { if v[i] < '0' || v[i] > '9' { @@ -285,6 +289,7 @@ func (c *csvSource) Err() error { return nil } +// importCSV imports a CSV file into a table func importCSV(table, filename string, sep string, db *sql.DB) int64 { log.Infof("Seeding table: %s, From file: %s", table, filename) @@ -353,6 +358,7 @@ func logFunc(args ...interface{}) { } } +// avatarURL returns a random avatar URL func avatarURL(size int) string { if size == 0 { size = 200 diff --git a/cmd/cmd_serv.go b/cmd/cmd_serv.go index fc2bd161..98f130cc 100644 --- a/cmd/cmd_serv.go +++ b/cmd/cmd_serv.go @@ -7,6 +7,7 @@ import ( var deployActive bool +// servCmd is the cobra CLI command for the serve subcommand func servCmd() *cobra.Command { c := &cobra.Command{ Use: "serve", @@ -18,6 +19,7 @@ func servCmd() *cobra.Command { return c } +// cmdServ is the handler for the serve subcommand func cmdServ(*cobra.Command, []string) { setup(cpath) diff --git a/cmd/cmd_version.go b/cmd/cmd_version.go index 245d6c4f..8738a051 100644 --- a/cmd/cmd_version.go +++ b/cmd/cmd_version.go @@ -7,6 +7,7 @@ import ( "github.com/spf13/cobra" ) +// This is cobra CLI command for the version subcommand func versionCmd() *cobra.Command { c := &cobra.Command{ Use: "version", @@ -16,10 +17,12 @@ func versionCmd() *cobra.Command { return c } +// cmdVersion is the handler for the version subcommand func cmdVersion(cmd *cobra.Command, args []string) { fmt.Printf("%s\n", BuildDetails()) } +// BuildDetails returns the version information func BuildDetails() string { if version == "" { return ` diff --git a/conf/config.go b/conf/config.go index 3db9f5b0..dcd12654 100644 --- a/conf/config.go +++ b/conf/config.go @@ -13,6 +13,7 @@ type configInfo struct { Inherits string } +// NewConfig creates a new config object func NewConfig(configPath, configFile string) (c *core.Config, err error) { fs := core.NewOsFS(configPath) if c, err = NewConfigWithFS(fs, configFile); err != nil { @@ -21,6 +22,7 @@ func NewConfig(configPath, configFile string) (c *core.Config, err error) { return } +// NewConfigWithFS creates a new config object using the provided filesystem func NewConfigWithFS(fs core.FS, configFile string) (c *core.Config, err error) { c = &core.Config{FS: fs} var ci configInfo @@ -47,6 +49,7 @@ func NewConfigWithFS(fs core.FS, configFile string) (c *core.Config, err error) return } +// readConfig reads the config file and unmarshals it into the provided struct func readConfig(fs core.FS, configFile string, v interface{}) (err error) { format := filepath.Ext(configFile) diff --git a/core/api.go b/core/api.go index 5fd647ef..b615b028 100644 --- a/core/api.go +++ b/core/api.go @@ -46,36 +46,36 @@ const ( // GraphJin struct is an instance of the GraphJin engine it holds all the required information like // datase schemas, relationships, etc that the GraphQL to SQL compiler would need to do it's job. -type graphjin struct { - conf *Config - db *sql.DB - log *_log.Logger - fs FS - trace Tracer - dbtype string - dbinfo *sdata.DBInfo - schema *sdata.DBSchema - allowList *allow.List - encKey [32]byte - encKeySet bool - cache Cache - queries sync.Map - roles map[string]*Role - roleStmt string - roleStmtMD psql.Metadata - tmap map[string]qcode.TConfig - rtmap map[string]ResolverFn - rmap map[string]resItem - abacEnabled bool - qc *qcode.Compiler - pc *psql.Compiler - subs sync.Map - prod bool - prodSec bool - namespace string - pf []byte - opts []Option - done chan bool +type GraphjinEngine struct { + conf *Config + db *sql.DB + log *_log.Logger + fs FS + trace Tracer + dbtype string + dbinfo *sdata.DBInfo + schema *sdata.DBSchema + allowList *allow.List + encryptionKey [32]byte + encryptionKeySet bool + cache Cache + queries sync.Map + roles map[string]*Role + roleStatement string + roleStatementMetadata psql.Metadata + tmap map[string]qcode.TConfig + rtmap map[string]ResolverFn + rmap map[string]resItem + abacEnabled bool + qcodeCompiler *qcode.Compiler + psqlCompiler *psql.Compiler + subs sync.Map + prod bool + prodSec bool + namespace string + printFormat []byte + opts []Option + done chan bool } type GraphJin struct { @@ -83,7 +83,7 @@ type GraphJin struct { done chan bool } -type Option func(*graphjin) error +type Option func(*GraphjinEngine) error // NewGraphJin creates the GraphJin struct, this involves querying the database to learn its // schemas and relationships @@ -104,6 +104,7 @@ func NewGraphJin(conf *Config, db *sql.DB, options ...Option) (g *GraphJin, err return } +// NewGraphJinWithFS creates the GraphJin struct, this involves querying the database to learn its func NewGraphJinWithFS(conf *Config, db *sql.DB, fs FS, options ...Option) (g *GraphJin, err error) { g = &GraphJin{done: make(chan bool)} if err = g.newGraphJin(conf, db, nil, fs, options...); err != nil { @@ -116,6 +117,7 @@ func NewGraphJinWithFS(conf *Config, db *sql.DB, fs FS, options ...Option) (g *G return } +// newGraphJinWithDBInfo creates the GraphJin struct, this involves querying the database to learn its // it all starts here func (g *GraphJin) newGraphJin(conf *Config, db *sql.DB, @@ -129,18 +131,18 @@ func (g *GraphJin) newGraphJin(conf *Config, t := time.Now() - gj := &graphjin{ - conf: conf, - db: db, - dbinfo: dbinfo, - log: _log.New(os.Stdout, "", 0), - prod: conf.Production, - prodSec: conf.Production, - pf: []byte(fmt.Sprintf("gj/%x:", t.Unix())), - opts: options, - fs: fs, - trace: &tracer{}, - done: g.done, + gj := &GraphjinEngine{ + conf: conf, + db: db, + dbinfo: dbinfo, + log: _log.New(os.Stdout, "", 0), + prod: conf.Production, + prodSec: conf.Production, + printFormat: []byte(fmt.Sprintf("gj/%x:", t.Unix())), + opts: options, + fs: fs, + trace: &tracer{}, + done: g.done, } if gj.conf.DisableProdSecurity { @@ -193,8 +195,8 @@ func (g *GraphJin) newGraphJin(conf *Config, if conf.SecretKey != "" { sk := sha256.Sum256([]byte(conf.SecretKey)) - gj.encKey = sk - gj.encKeySet = true + gj.encryptionKey = sk + gj.encryptionKeySet = true } g.Store(gj) @@ -202,28 +204,31 @@ func (g *GraphJin) newGraphJin(conf *Config, } func OptionSetNamespace(namespace string) Option { - return func(s *graphjin) error { + return func(s *GraphjinEngine) error { s.namespace = namespace return nil } } +// OptionSetFS sets the file system to be used by GraphJin func OptionSetFS(fs FS) Option { - return func(s *graphjin) error { + return func(s *GraphjinEngine) error { s.fs = fs return nil } } +// OptionSetTrace sets the tracer to be used by GraphJin func OptionSetTrace(trace Tracer) Option { - return func(s *graphjin) error { + return func(s *GraphjinEngine) error { s.trace = trace return nil } } +// OptionSetResolver sets the resolver function to be used by GraphJin func OptionSetResolver(name string, fn ResolverFn) Option { - return func(s *graphjin) error { + return func(s *GraphjinEngine) error { if s.rtmap == nil { s.rtmap = s.newRTMap() } @@ -242,8 +247,8 @@ type Error struct { // Result struct contains the output of the GraphQL function this includes resulting json from the // database query and any error information type Result struct { - ns string - op qcode.QType + namespace string + operation qcode.QType name string sql string role string @@ -256,8 +261,8 @@ type Result struct { // Extensions *extensions `json:"extensions,omitempty"` } -// ReqConfig is used to pass request specific config values to the GraphQL and Subscribe functions. Dynamic variables can be set here. -type ReqConfig struct { +// RequestConfig is used to pass request specific config values to the GraphQL and Subscribe functions. Dynamic variables can be set here. +type RequestConfig struct { ns *string // APQKey is set when using GraphJin with automatic persisted queries @@ -271,12 +276,12 @@ type ReqConfig struct { } // SetNamespace is used to set namespace requests within a single instance of GraphJin. For example queries with the same name -func (rc *ReqConfig) SetNamespace(ns string) { +func (rc *RequestConfig) SetNamespace(ns string) { rc.ns = &ns } // GetNamespace is used to get the namespace requests within a single instance of GraphJin -func (rc *ReqConfig) GetNamespace() (string, bool) { +func (rc *RequestConfig) GetNamespace() (string, bool) { if rc.ns != nil { return *rc.ns, true } @@ -294,9 +299,9 @@ func (rc *ReqConfig) GetNamespace() (string, bool) { func (g *GraphJin) GraphQL(c context.Context, query string, vars json.RawMessage, - rc *ReqConfig, + rc *RequestConfig, ) (res *Result, err error) { - gj := g.Load().(*graphjin) + gj := g.Load().(*GraphjinEngine) c1, span := gj.spanStart(c, "GraphJin Query") defer span.End() @@ -347,7 +352,7 @@ func (g *GraphJin) GraphQL(c context.Context, // if not production then save to allow list if !gj.prod && r.name != "IntrospectionQuery" { - if err = gj.saveToAllowList(resp.qc, resp.res.ns); err != nil { + if err = gj.saveToAllowList(resp.qc, resp.res.namespace); err != nil { return } } @@ -360,10 +365,10 @@ func (g *GraphJin) GraphQLTx(c context.Context, tx *sql.Tx, query string, vars json.RawMessage, - rc *ReqConfig, + rc *RequestConfig, ) (res *Result, err error) { if rc == nil { - rc = &ReqConfig{Tx: tx} + rc = &RequestConfig{Tx: tx} } else { rc.Tx = tx } @@ -375,9 +380,9 @@ func (g *GraphJin) GraphQLTx(c context.Context, func (g *GraphJin) GraphQLByName(c context.Context, name string, vars json.RawMessage, - rc *ReqConfig, + rc *RequestConfig, ) (res *Result, err error) { - gj := g.Load().(*graphjin) + gj := g.Load().(*GraphjinEngine) c1, span := gj.spanStart(c, "GraphJin Query") defer span.End() @@ -401,78 +406,79 @@ func (g *GraphJin) GraphQLByNameTx(c context.Context, tx *sql.Tx, name string, vars json.RawMessage, - rc *ReqConfig, + rc *RequestConfig, ) (res *Result, err error) { if rc == nil { - rc = &ReqConfig{Tx: tx} + rc = &RequestConfig{Tx: tx} } else { rc.Tx = tx } return g.GraphQLByName(c, name, vars, rc) } -type graphqlReq struct { - ns string - op qcode.QType - name string - query []byte - vars json.RawMessage - aschema map[string]json.RawMessage - rc *ReqConfig +type GraphqlReq struct { + namespace string + operation qcode.QType + name string + query []byte + vars json.RawMessage + aschema map[string]json.RawMessage + requestconfig *RequestConfig } -type graphqlResp struct { +type GraphqlResponse struct { res Result qc *qcode.QCode } -func (gj *graphjin) newGraphqlReq(rc *ReqConfig, +// newGraphqlReq creates a new GraphQL request +func (gj *GraphjinEngine) newGraphqlReq(rc *RequestConfig, op string, name string, query []byte, vars json.RawMessage, -) (r graphqlReq) { - r = graphqlReq{ - op: qcode.GetQTypeByName(op), - name: name, - query: query, - vars: vars, +) (r GraphqlReq) { + r = GraphqlReq{ + operation: qcode.GetQTypeByName(op), + name: name, + query: query, + vars: vars, } if rc != nil { - r.rc = rc + r.requestconfig = rc } if rc != nil && rc.ns != nil { - r.ns = *rc.ns + r.namespace = *rc.ns } else { - r.ns = gj.namespace + r.namespace = gj.namespace } return } // Set is used to set the namespace, operation type, name and query for the GraphQL request -func (r *graphqlReq) Set(item allow.Item) { - r.ns = item.Namespace - r.op = qcode.GetQTypeByName(item.Operation) +func (r *GraphqlReq) Set(item allow.Item) { + r.namespace = item.Namespace + r.operation = qcode.GetQTypeByName(item.Operation) r.name = item.Name r.query = item.Query r.aschema = item.ActionJSON } // GraphQL function is our main function it takes a GraphQL query compiles it -func (gj *graphjin) queryWithResult(c context.Context, r graphqlReq) (res *Result, err error) { +func (gj *GraphjinEngine) queryWithResult(c context.Context, r GraphqlReq) (res *Result, err error) { resp, err := gj.query(c, r) return &resp.res, err } // GraphQL function is our main function it takes a GraphQL query compiles it -func (gj *graphjin) query(c context.Context, r graphqlReq) ( - resp graphqlResp, err error, +func (gj *GraphjinEngine) query(c context.Context, r GraphqlReq) ( + resp GraphqlResponse, err error, ) { resp.res = Result{ - ns: r.ns, - op: r.op, - name: r.name, + namespace: r.namespace, + operation: r.operation, + name: r.name, } if !gj.prodSec && r.name == "IntrospectionQuery" { @@ -480,12 +486,12 @@ func (gj *graphjin) query(c context.Context, r graphqlReq) ( return } - if r.op == qcode.QTSubscription { + if r.operation == qcode.QTSubscription { err = errors.New("use 'core.Subscribe' for subscriptions") return } - if r.op == qcode.QTMutation && gj.schema.DBType() == "mysql" { + if r.operation == qcode.QTMutation && gj.schema.DBType() == "mysql" { err = errors.New("mysql: mutations not supported") return } @@ -519,15 +525,16 @@ func (g *GraphJin) Reload() error { return g.reload(nil) } +// reload redoes database discover and reinitializes GraphJin. func (g *GraphJin) reload(di *sdata.DBInfo) (err error) { - gj := g.Load().(*graphjin) + gj := g.Load().(*GraphjinEngine) err = g.newGraphJin(gj.conf, gj.db, di, gj.fs, gj.opts...) return } // IsProd return true for production mode or false for development mode func (g *GraphJin) IsProd() bool { - gj := g.Load().(*graphjin) + gj := g.Load().(*GraphjinEngine) return gj.prod } @@ -546,6 +553,7 @@ func Operation(query string) (h Header, err error) { return } +// getFS returns the file system to be used by GraphJin func getFS(conf *Config) (fs FS, err error) { if v, ok := conf.FS.(FS); ok { fs = v @@ -561,6 +569,7 @@ func getFS(conf *Config) (fs FS, err error) { return } +// newError creates a new error list func newError(err error) (errList []Error) { errList = []Error{{Message: err.Error()}} return diff --git a/core/args.go b/core/args.go index ae27b60d..755fa7d2 100644 --- a/core/args.go +++ b/core/args.go @@ -18,10 +18,10 @@ type args struct { cindx int // index of cursor arg } -func (gj *graphjin) argList(c context.Context, +func (gj *GraphjinEngine) argList(c context.Context, md psql.Metadata, fields map[string]json.RawMessage, - rc *ReqConfig, + rc *RequestConfig, buildJSON bool, ) (ar args, err error) { ar = args{cindx: -1} diff --git a/core/cache.go b/core/cache.go index a605d13b..52a03d64 100644 --- a/core/cache.go +++ b/core/cache.go @@ -8,11 +8,13 @@ type Cache struct { cache *lru.TwoQueueCache } -func (gj *graphjin) initCache() (err error) { +// initCache initializes the cache +func (gj *GraphjinEngine) initCache() (err error) { gj.cache.cache, err = lru.New2Q(500) return } +// Get returns the value from the cache func (c Cache) Get(key string) (val []byte, fromCache bool) { if v, ok := c.cache.Get(key); ok { val = v.([]byte) @@ -21,6 +23,7 @@ func (c Cache) Get(key string) (val []byte, fromCache bool) { return } +// Set sets the value in the cache func (c Cache) Set(key string, val []byte) { c.cache.Add(key, val) } diff --git a/core/config.go b/core/config.go index e84f0011..c13e56ed 100644 --- a/core/config.go +++ b/core/config.go @@ -267,7 +267,7 @@ type ResolverReq struct { ID string Sel *qcode.Select Log *log.Logger - *ReqConfig + *RequestConfig } // AddRoleTable function is a helper function to make it easy to add per-table diff --git a/core/core.go b/core/core.go index 8febb823..b00ef3f4 100644 --- a/core/core.go +++ b/core/core.go @@ -56,7 +56,7 @@ const ( // Duration time.Duration `json:"duration"` // } -func (gj *graphjin) getIntroResult() (data json.RawMessage, err error) { +func (gj *GraphjinEngine) getIntroResult() (data json.RawMessage, err error) { var ok bool if data, ok = gj.cache.Get("_intro"); ok { return @@ -69,7 +69,7 @@ func (gj *graphjin) getIntroResult() (data json.RawMessage, err error) { } // Initializes the database discovery process on graphjin -func (gj *graphjin) initDiscover() (err error) { +func (gj *GraphjinEngine) initDiscover() (err error) { switch gj.conf.DBType { case "": gj.dbtype = "postgres" @@ -86,7 +86,7 @@ func (gj *graphjin) initDiscover() (err error) { } // Private method that does the actual database discovery for initDiscover -func (gj *graphjin) _initDiscover() (err error) { +func (gj *GraphjinEngine) _initDiscover() (err error) { if gj.prod && gj.conf.EnableSchema { b, err := gj.fs.Get("db.graphql") if err != nil { @@ -132,14 +132,14 @@ func (gj *graphjin) _initDiscover() (err error) { } // Initializes the database schema on graphjin -func (gj *graphjin) initSchema() error { +func (gj *GraphjinEngine) initSchema() error { if err := gj._initSchema(); err != nil { return fmt.Errorf("%s: %w", gj.dbtype, err) } return nil } -func (gj *graphjin) _initSchema() (err error) { +func (gj *GraphjinEngine) _initSchema() (err error) { if len(gj.dbinfo.Tables) == 0 { return fmt.Errorf("no tables found in database") } @@ -178,7 +178,7 @@ func (gj *graphjin) _initSchema() (err error) { return } -func (gj *graphjin) initIntro() (err error) { +func (gj *GraphjinEngine) initIntro() (err error) { if !gj.prod && gj.conf.EnableIntrospection { var introJSON json.RawMessage introJSON, err = gj.getIntroResult() @@ -194,7 +194,7 @@ func (gj *graphjin) initIntro() (err error) { } // Initializes the qcode compilers -func (gj *graphjin) initCompilers() (err error) { +func (gj *GraphjinEngine) initCompilers() (err error) { qcc := qcode.Config{ TConfig: gj.tmap, DefaultBlock: gj.conf.DefaultBlock, @@ -206,29 +206,29 @@ func (gj *graphjin) initCompilers() (err error) { Validators: valid.Validators, } - gj.qc, err = qcode.NewCompiler(gj.schema, qcc) + gj.qcodeCompiler, err = qcode.NewCompiler(gj.schema, qcc) if err != nil { return } - if err = addRoles(gj.conf, gj.qc); err != nil { + if err = addRoles(gj.conf, gj.qcodeCompiler); err != nil { return } - gj.pc = psql.NewCompiler(psql.Config{ + gj.psqlCompiler = psql.NewCompiler(psql.Config{ Vars: gj.conf.Vars, DBType: gj.schema.DBType(), DBVersion: gj.schema.DBVersion(), - SecPrefix: gj.pf, + SecPrefix: gj.printFormat, EnableCamelcase: gj.conf.EnableCamelcase, }) return } -func (gj *graphjin) executeRoleQuery(c context.Context, +func (gj *GraphjinEngine) executeRoleQuery(c context.Context, conn *sql.Conn, vmap map[string]json.RawMessage, - rc *ReqConfig, + rc *RequestConfig, ) (role string, err error) { if c.Value(UserIDKey) == nil { role = "anon" @@ -236,7 +236,7 @@ func (gj *graphjin) executeRoleQuery(c context.Context, } ar, err := gj.argList(c, - gj.roleStmtMD, + gj.roleStatementMetadata, vmap, rc, false) @@ -266,9 +266,9 @@ func (gj *graphjin) executeRoleQuery(c context.Context, err = retryOperation(c1, func() error { var row *sql.Row if rc != nil && rc.Tx != nil { - row = rc.Tx.QueryRowContext(c1, gj.roleStmt, ar.values...) + row = rc.Tx.QueryRowContext(c1, gj.roleStatement, ar.values...) } else { - row = conn.QueryRowContext(c1, gj.roleStmt, ar.values...) + row = conn.QueryRowContext(c1, gj.roleStatement, ar.values...) } return row.Scan(&role) }) @@ -283,7 +283,7 @@ func (gj *graphjin) executeRoleQuery(c context.Context, // Returns the operation type for the query result func (r *Result) Operation() OpType { - switch r.op { + switch r.operation { case qcode.QTQuery: return OpQuery @@ -297,12 +297,12 @@ func (r *Result) Operation() OpType { // Returns the namespace for the query result func (r *Result) Namespace() string { - return r.ns + return r.namespace } // Returns the operation name for the query result func (r *Result) OperationName() string { - return r.op.String() + return r.operation.String() } // Returns the query name for the query result @@ -368,6 +368,7 @@ func (r *Result) CacheControl() string { // append(c.res.Extensions.Tracing.Execution.Resolvers, tr) // } +// debugLogStmt logs the query statement for debugging func (s *gstate) debugLogStmt() { st := s.cs.st @@ -386,7 +387,7 @@ func (s *gstate) debugLogStmt() { } // Saved the query qcode to the allow list -func (gj *graphjin) saveToAllowList(qc *qcode.QCode, ns string) (err error) { +func (gj *GraphjinEngine) saveToAllowList(qc *qcode.QCode, ns string) (err error) { if gj.conf.DisableAllowList { return nil } @@ -416,7 +417,7 @@ func (gj *graphjin) saveToAllowList(qc *qcode.QCode, ns string) (err error) { } // Starts tracing with the given name -func (gj *graphjin) spanStart(c context.Context, name string) (context.Context, Spaner) { +func (gj *GraphjinEngine) spanStart(c context.Context, name string) (context.Context, Spaner) { return gj.trace.Start(c, name) } diff --git a/core/crypt.go b/core/crypt.go index bb8c9342..50277977 100644 --- a/core/crypt.go +++ b/core/crypt.go @@ -7,6 +7,11 @@ import ( "encoding/base64" ) +// encryptValues encrypts the values in the data using the given key +// data: the data to encrypt +// encPrefix: the prefix to search for the values to encrypt +// decPrefix: the prefix to replace the values with +// nonce: the nonce to use for encryption func encryptValues( data, encPrefix, decPrefix, nonce []byte, key [32]byte) ([]byte, error) { @@ -78,6 +83,10 @@ func encryptValues( return b.Bytes(), nil } +// decryptValues decrypts the values in the data using the given key +// data: the data to decrypt +// prefix: the prefix to search for the values to decrypt +// key: the key to use for decryption func decryptValues(data, prefix []byte, key [32]byte) ([]byte, error) { var s, e int if e = bytes.Index(data[s:], prefix); e == -1 { @@ -151,6 +160,7 @@ func decryptValues(data, prefix []byte, key [32]byte) ([]byte, error) { return b.Bytes(), nil } +// firstCursorValue returns the first cursor value in the data func firstCursorValue(data []byte, prefix []byte) []byte { var buf [100]byte pf := append(buf[:0], prefix...) diff --git a/core/gstate.go b/core/gstate.go index bb5ed17f..8d305bfd 100644 --- a/core/gstate.go +++ b/core/gstate.go @@ -16,8 +16,8 @@ import ( ) type gstate struct { - gj *graphjin - r graphqlReq + gj *GraphjinEngine + r GraphqlReq cs *cstate vmap map[string]json.RawMessage data []byte @@ -40,7 +40,7 @@ type stmt struct { sql string } -func newGState(c context.Context, gj *graphjin, r graphqlReq) (s gstate, err error) { +func newGState(c context.Context, gj *GraphjinEngine, r GraphqlReq) (s gstate, err error) { s.gj = gj s.r = r @@ -58,7 +58,7 @@ func newGState(c context.Context, gj *graphjin, r graphqlReq) (s gstate, err err // convert variable json to a go map also decrypted encrypted values if len(r.vars) != 0 { var vars json.RawMessage - vars, err = decryptValues(r.vars, decPrefix, s.gj.encKey) + vars, err = decryptValues(r.vars, decPrefix, s.gj.encryptionKey) if err != nil { return } @@ -113,16 +113,16 @@ func (s *gstate) compileQueryForRole() (err error) { vars = s.vmap } - if st.qc, err = s.gj.qc.Compile( + if st.qc, err = s.gj.qcodeCompiler.Compile( s.r.query, vars, s.role, - s.r.ns); err != nil { + s.r.namespace); err != nil { return } var w bytes.Buffer - if st.md, err = s.gj.pc.Compile(&w, st.qc); err != nil { + if st.md, err = s.gj.psqlCompiler.Compile(&w, st.qc); err != nil { return } @@ -254,7 +254,7 @@ func (s *gstate) execute(c context.Context, conn *sql.Conn) (err error) { if span.IsRecording() { span.SetAttributesString( - StringAttr{"query.namespace", s.r.ns}, + StringAttr{"query.namespace", s.r.namespace}, StringAttr{"query.operation", cs.st.qc.Type.String()}, StringAttr{"query.name", cs.st.qc.Name}, StringAttr{"query.role", cs.st.role}) @@ -270,25 +270,25 @@ func (s *gstate) execute(c context.Context, conn *sql.Conn) (err error) { s.dhash = sha256.Sum256(s.data) s.data, err = encryptValues(s.data, - s.gj.pf, decPrefix, s.dhash[:], s.gj.encKey) + s.gj.printFormat, decPrefix, s.dhash[:], s.gj.encryptionKey) return } func (s *gstate) executeRoleQuery(c context.Context, conn *sql.Conn) (err error) { - s.role, err = s.gj.executeRoleQuery(c, conn, s.vmap, s.r.rc) + s.role, err = s.gj.executeRoleQuery(c, conn, s.vmap, s.r.requestconfig) return } func (s *gstate) argList(c context.Context) (args args, err error) { - args, err = s.gj.argList(c, s.cs.st.md, s.vmap, s.r.rc, false) + args, err = s.gj.argList(c, s.cs.st.md, s.vmap, s.r.requestconfig, false) return } func (s *gstate) argListForSub(c context.Context, vmap map[string]json.RawMessage, ) (args args, err error) { - args, err = s.gj.argList(c, s.cs.st.md, vmap, s.r.rc, true) + args, err = s.gj.argList(c, s.cs.st.md, vmap, s.r.requestconfig, true) return } @@ -354,13 +354,13 @@ func (s *gstate) qcode() (qc *qcode.QCode) { } func (s *gstate) tx() (tx *sql.Tx) { - if s.r.rc != nil { - tx = s.r.rc.Tx + if s.r.requestconfig != nil { + tx = s.r.requestconfig.Tx } return } func (s *gstate) key() (key string) { - key = s.r.ns + s.r.name + s.role + key = s.r.namespace + s.r.name + s.role return } diff --git a/core/init.go b/core/init.go index aa20d408..9fcbb4a2 100644 --- a/core/init.go +++ b/core/init.go @@ -11,17 +11,17 @@ import ( ) // Initializes the graphjin instance with the config -func (gj *graphjin) initConfig() error { +func (gj *GraphjinEngine) initConfig() error { c := gj.conf - tm := make(map[string]struct{}) + tableMap := make(map[string]struct{}) - for _, t := range c.Tables { - k := t.Schema + t.Name - if _, ok := tm[k]; ok { - return fmt.Errorf("duplicate table found: %s", t.Name) + for _, table := range c.Tables { + k := table.Schema + table.Name + if _, ok := tableMap[k]; ok { + return fmt.Errorf("duplicate table found: %s", table.Name) } - tm[k] = struct{}{} + tableMap[k] = struct{}{} } for k, v := range c.Vars { @@ -84,7 +84,8 @@ func (gj *graphjin) initConfig() error { return nil } -func (gj *graphjin) addTableInfo(t Table) error { +// addTableInfo adds table info to the compiler +func (gj *GraphjinEngine) addTableInfo(t Table) error { obm := map[string][][2]string{} for k, ob := range t.OrderBy { @@ -103,6 +104,7 @@ func (gj *graphjin) addTableInfo(t Table) error { return nil } +// getDBTableAliases returns a map of table aliases func getDBTableAliases(c *Config) map[string][]string { m := make(map[string][]string, len(c.Tables)) @@ -116,7 +118,8 @@ func getDBTableAliases(c *Config) map[string][]string { return m } -func addTables(conf *Config, di *sdata.DBInfo) error { +// addTables adds tables to the database info +func addTables(conf *Config, dbInfo *sdata.DBInfo) error { var err error for _, t := range conf.Tables { @@ -126,13 +129,13 @@ func addTables(conf *Config, di *sdata.DBInfo) error { } switch t.Type { case "json", "jsonb": - err = addJsonTable(conf, di, t) + err = addJsonTable(conf, dbInfo, t) case "polymorphic": - err = addVirtualTable(conf, di, t) + err = addVirtualTable(conf, dbInfo, t) default: - err = updateTable(conf, di, t) + err = updateTable(conf, dbInfo, t) } if err != nil { @@ -143,14 +146,15 @@ func addTables(conf *Config, di *sdata.DBInfo) error { return nil } -func updateTable(conf *Config, di *sdata.DBInfo, t Table) error { - t1, err := di.GetTable(t.Schema, t.Name) +// updateTable updates the table info in the database info +func updateTable(conf *Config, dbInfo *sdata.DBInfo, table Table) error { + t1, err := dbInfo.GetTable(table.Schema, table.Name) if err != nil { return fmt.Errorf("table: %w", err) } - for _, c := range t.Columns { - c1, err := di.GetColumn(t.Schema, t.Name, c.Name) + for _, c := range table.Columns { + c1, err := dbInfo.GetColumn(table.Schema, table.Name, c.Name) if err != nil { return err } @@ -168,18 +172,19 @@ func updateTable(conf *Config, di *sdata.DBInfo, t Table) error { return nil } -func addJsonTable(conf *Config, di *sdata.DBInfo, t Table) error { +// addJsonTable adds a json table to the database info +func addJsonTable(conf *Config, dbInfo *sdata.DBInfo, table Table) error { // This is for jsonb column that want to be a table. - if t.Table == "" { - return fmt.Errorf("json table: set the 'table' for column '%s'", t.Name) + if table.Table == "" { + return fmt.Errorf("json table: set the 'table' for column '%s'", table.Name) } - bc, err := di.GetColumn(t.Schema, t.Table, t.Name) + bc, err := dbInfo.GetColumn(table.Schema, table.Table, table.Name) if err != nil { return fmt.Errorf("json table: %w", err) } - bt, err := di.GetTable(bc.Schema, bc.Table) + bt, err := dbInfo.GetTable(bc.Schema, bc.Table) if err != nil { return fmt.Errorf("json table: %w", err) } @@ -187,23 +192,23 @@ func addJsonTable(conf *Config, di *sdata.DBInfo, t Table) error { if bc.Type != "json" && bc.Type != "jsonb" { return fmt.Errorf( "json table: column '%s' in table '%s' is of type '%s'. Only JSON or JSONB is valid", - t.Name, t.Table, bc.Type) + table.Name, table.Table, bc.Type) } - columns := make([]sdata.DBColumn, 0, len(t.Columns)) + columns := make([]sdata.DBColumn, 0, len(table.Columns)) - for i := range t.Columns { - c := t.Columns[i] + for i := range table.Columns { + c := table.Columns[i] columns = append(columns, sdata.DBColumn{ ID: -1, Schema: bc.Schema, - Table: t.Name, + Table: table.Name, Name: c.Name, Type: c.Type, }) if c.Type == "" { return fmt.Errorf("json table: type parameter missing for column: %s.%s'", - t.Name, c.Name) + table.Name, c.Name) } } @@ -216,14 +221,15 @@ func addJsonTable(conf *Config, di *sdata.DBInfo, t Table) error { Type: bc.Type, } - nt := sdata.NewDBTable(bc.Schema, t.Name, bc.Type, columns) + nt := sdata.NewDBTable(bc.Schema, table.Name, bc.Type, columns) nt.PrimaryCol = col1 nt.SecondaryCol = bt.PrimaryCol - di.AddTable(nt) + dbInfo.AddTable(nt) return nil } +// addVirtualTable adds a virtual table to the database info func addVirtualTable(conf *Config, di *sdata.DBInfo, t Table) error { if len(t.Columns) == 0 { return fmt.Errorf("polymorphic table: no id column specified") @@ -249,6 +255,7 @@ func addVirtualTable(conf *Config, di *sdata.DBInfo, t Table) error { return nil } +// addForeignKeys adds foreign keys to the database info func addForeignKeys(conf *Config, di *sdata.DBInfo) error { for _, t := range conf.Tables { if t.Type == "polymorphic" { @@ -266,6 +273,7 @@ func addForeignKeys(conf *Config, di *sdata.DBInfo) error { return nil } +// addForeignKey adds a foreign key to the database info func addForeignKey(conf *Config, di *sdata.DBInfo, c Column, t Table) error { c1, err := di.GetColumn(t.Schema, t.Name, c.Name) if err != nil { @@ -310,6 +318,7 @@ func addForeignKey(conf *Config, di *sdata.DBInfo, c Column, t Table) error { return nil } +// addRoles adds roles to the compiler func addRoles(c *Config, qc *qcode.Compiler) error { for _, r := range c.Roles { for _, t := range r.Tables { @@ -322,6 +331,7 @@ func addRoles(c *Config, qc *qcode.Compiler) error { return nil } +// addRole adds a role to the compiler func addRole(qc *qcode.Compiler, r Role, t RoleTable, defaultBlock bool) error { ro := false // read-only @@ -392,10 +402,12 @@ func addRole(qc *qcode.Compiler, r Role, t RoleTable, defaultBlock bool) error { }) } +// GetTable returns a table from the role func (r *Role) GetTable(schema, name string) *RoleTable { return r.tm[name] } +// getFK returns the foreign key for the column func (c *Column) getFK(defaultSchema string) ([3]string, bool) { var ret [3]string var ok bool @@ -412,10 +424,12 @@ func (c *Column) getFK(defaultSchema string) ([3]string, bool) { return ret, ok } +// sanitize trims the value func sanitize(value string) string { return strings.TrimSpace(value) } +// isASCII checks if the string is ASCII func isASCII(s string) (int, bool) { for i := 0; i < len(s); i++ { if s[i] > unicode.MaxASCII { @@ -425,7 +439,8 @@ func isASCII(s string) (int, bool) { return -1, true } -func (gj *graphjin) initAllowList() (err error) { +// initAllowList initializes the allow list +func (gj *GraphjinEngine) initAllowList() (err error) { gj.allowList, err = allow.New( gj.log, gj.fs, diff --git a/core/internal/allow/allow.go b/core/internal/allow/allow.go index 3df9fb0b..b069bef9 100644 --- a/core/internal/allow/allow.go +++ b/core/internal/allow/allow.go @@ -22,7 +22,7 @@ type FS interface { var ErrUnknownGraphQLQuery = errors.New("unknown graphql query") const ( - queryPath = "/queries" + QUERY_PATH = "/queries" ) type Item struct { @@ -45,6 +45,7 @@ type List struct { fs FS } +// New creates a new allow list func New(log *_log.Logger, fs FS, readOnly bool) (al *List, err error) { if fs == nil { return nil, fmt.Errorf("no filesystem defined for the allow list") @@ -77,6 +78,7 @@ func New(log *_log.Logger, fs FS, readOnly bool) (al *List, err error) { return al, err } +// Set adds a new query to the allow list func (al *List) Set(item Item) error { if al.saveChan == nil { return errors.New("allow list is read-only") @@ -90,6 +92,7 @@ func (al *List) Set(item Item) error { return nil } +// GetByName returns a query by name func (al *List) GetByName(name string, useCache bool) (item Item, err error) { if useCache { if v, ok := al.cache.Get(name); ok { @@ -98,26 +101,27 @@ func (al *List) GetByName(name string, useCache bool) (item Item, err error) { } } - fp := filepath.Join(queryPath, name) + fp := filepath.Join(QUERY_PATH, name) var ok bool if ok, err = al.fs.Exists((fp + ".gql")); err != nil { return } else if ok { - item, err = al.get(queryPath, name, ".gql", useCache) + item, err = al.get(QUERY_PATH, name, ".gql", useCache) return } if ok, err = al.fs.Exists((fp + ".graphql")); err != nil { return } else if ok { - item, err = al.get(queryPath, name, ".gql", useCache) + item, err = al.get(QUERY_PATH, name, ".gql", useCache) } else { err = ErrUnknownGraphQLQuery } return } +// get returns a query by name func (al *List) get(queryPath, name, ext string, useCache bool) (item Item, err error) { queryNS, queryName := splitName(name) @@ -161,6 +165,7 @@ func (al *List) get(queryPath, name, ext string, useCache bool) (item Item, err return } +// save saves a query to the allow list func (al *List) save(item Item) (err error) { item.Name = strings.TrimSpace(item.Name) if item.Name == "" { @@ -170,6 +175,7 @@ func (al *List) save(item Item) (err error) { return al.saveItem(item) } +// saveItem saves a query to the allow list func (al *List) saveItem(item Item) (err error) { var queryFile string if item.Namespace != "" { @@ -196,7 +202,7 @@ func (al *List) saveItem(item Item) (err error) { fmap[fragFile] = struct{}{} } - ff := filepath.Join(queryPath, "fragments", (fragFile + ".gql")) + ff := filepath.Join(QUERY_PATH, "fragments", (fragFile + ".gql")) err = al.fs.Put(ff, []byte(f.Value)) if err != nil { return @@ -207,7 +213,7 @@ func (al *List) saveItem(item Item) (err error) { } buf.Write(bytes.TrimSpace(item.Query)) - qf := filepath.Join(queryPath, (queryFile + ".gql")) + qf := filepath.Join(QUERY_PATH, (queryFile + ".gql")) err = al.fs.Put(qf, bytes.TrimSpace(buf.Bytes())) if err != nil { return @@ -215,7 +221,7 @@ func (al *List) saveItem(item Item) (err error) { if len(item.ActionJSON) != 0 { var vars []byte - jf := filepath.Join(queryPath, (queryFile + ".json")) + jf := filepath.Join(QUERY_PATH, (queryFile + ".json")) vars, err = json.MarshalIndent(item.ActionJSON, "", " ") if err != nil { return @@ -225,12 +231,13 @@ func (al *List) saveItem(item Item) (err error) { return } -func splitName(v string) (string, string) { - i := strings.LastIndex(v, ".") +// splitName splits a name into namespace and name +func splitName(name string) (string, string) { + i := strings.LastIndex(name, ".") if i == -1 { - return "", v - } else if i < len(v)-1 { - return v[:i], v[(i + 1):] + return "", name + } else if i < len(name)-1 { + return name[:i], name[(i + 1):] } return "", "" } diff --git a/core/internal/allow/gql.go b/core/internal/allow/gql.go index 5ccf52b5..a18301f9 100644 --- a/core/internal/allow/gql.go +++ b/core/internal/allow/gql.go @@ -10,6 +10,7 @@ import ( var incRe = regexp.MustCompile(`(?m)#import \"(.+)\"`) +// readGQL reads a graphql file and resolves all imports func readGQL(fs FS, fname string) (gql []byte, err error) { var b bytes.Buffer @@ -28,6 +29,7 @@ func readGQL(fs FS, fname string) (gql []byte, err error) { return } +// parseGQL parses a graphql file and resolves all imports func parseGQL(fs FS, fname string, r io.Writer) (err error) { b, err := fs.Get(fname) if err != nil { diff --git a/core/internal/assert/assert.go b/core/internal/assert/assert.go index 73f889fc..f9377577 100644 --- a/core/internal/assert/assert.go +++ b/core/internal/assert/assert.go @@ -5,12 +5,14 @@ import ( "testing" ) +// Equals compares two values func Equals(t *testing.T, exp, got interface{}) { if !reflect.DeepEqual(exp, got) { t.Errorf("expected %v, got %v", exp, got) } } +// Empty checks if a slice is empty func Empty(t *testing.T, got interface{}) { val := reflect.ValueOf(got) if val.Kind() != reflect.Slice { @@ -24,12 +26,14 @@ func Empty(t *testing.T, got interface{}) { } } +// NoError checks if an error is nil func NoError(t *testing.T, err error) { if err != nil { t.Errorf("no errror expected, got %s", err.Error()) } } +// NoErrorFatal checks if an error is nil and fails the test func NoErrorFatal(t *testing.T, err error) { if err != nil { t.Fatalf("no errror expected, got %s", err.Error()) diff --git a/core/internal/graph/lex.go b/core/internal/graph/lex.go index bdb481cd..8257c6fc 100644 --- a/core/internal/graph/lex.go +++ b/core/internal/graph/lex.go @@ -79,8 +79,8 @@ var punctuators = map[rune]MType{ const eof = -1 -// stateFn represents the state of the scanner as a function that returns the next state. -type stateFn func(*lexer) stateFn +// StateFn represents the state of the scanner as a function that returns the next state. +type StateFn func(*lexer) StateFn // lexer holds the state of the scanner. type lexer struct { @@ -96,6 +96,7 @@ type lexer struct { var zeroLex = lexer{} +// Reset resets the lexer to scan a new input string. func (l *lexer) Reset() { *l = zeroLex } @@ -133,6 +134,7 @@ func (l *lexer) backup() { } } +// current returns the current bytes of the input. func (l *lexer) current() []byte { return l.input[l.start:l.pos] } @@ -151,6 +153,7 @@ func (l *lexer) emit(t MType) { l.start = l.pos } +// emitL passes an item back to the client and lowercases the value. func (l *lexer) emitL(t MType) { lowercase(l.current()) l.emit(t) @@ -199,7 +202,7 @@ func (l *lexer) acceptRun(valid []byte) { // errorf returns an error token and terminates the scan by passing // back a nil pointer that will be the next state, terminating l.nextItem. -func (l *lexer) errorf(format string, args ...interface{}) stateFn { +func (l *lexer) errorf(format string, args ...interface{}) StateFn { l.err = fmt.Errorf(format, args...) l.items = append(l.items, item{itemError, l.start, l.input[l.start:l.pos], l.line}) return nil @@ -233,7 +236,7 @@ func (l *lexer) run() { } // lexInsideAction scans the elements inside action delimiters. -func lexRoot(l *lexer) stateFn { +func lexRoot(l *lexer) StateFn { r := l.next() switch { @@ -287,7 +290,7 @@ func lexRoot(l *lexer) stateFn { } // lexName scans a name. -func lexName(l *lexer) stateFn { +func lexName(l *lexer) StateFn { for { r := l.next() @@ -317,7 +320,7 @@ func lexName(l *lexer) stateFn { } // lexString scans a string. -func lexString(l *lexer) stateFn { +func lexString(l *lexer) StateFn { if sr, ok := l.accept([]byte(quotesToken)); ok { l.ignore() @@ -351,7 +354,7 @@ func lexString(l *lexer) stateFn { // lexNumber scans a number: decimal and float. This isn't a perfect number scanner // for instance it accepts "." and "0x0.2" and "089" - but when it's wrong the input // is invalid and the parser (via strconv) should notice. -func lexNumber(l *lexer) stateFn { +func lexNumber(l *lexer) StateFn { if !l.scanNumber() { return l.errorf("bad number syntax: %q", l.input[l.start:l.pos]) } @@ -359,6 +362,7 @@ func lexNumber(l *lexer) stateFn { return lexRoot } +// scanNumber scans a number: decimal and float. func (l *lexer) scanNumber() bool { // Optional leading sign. l.accept(signsToken) @@ -391,14 +395,17 @@ func isAlphaNumeric(r rune) bool { return r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r) } +// equals reports whether b is equal to val. func equals(b, val []byte) bool { return bytes.EqualFold(b, val) } +// contains reports whether b contains any of the chars. func contains(b []byte, chars string) bool { return bytes.ContainsAny(b, chars) } +// lowercase lowercases the bytes in b. func lowercase(b []byte) { for i := 0; i < len(b); i++ { if b[i] >= 'A' && b[i] <= 'Z' { @@ -407,6 +414,7 @@ func lowercase(b []byte) { } } +// String returns a string representation of the item. func (i item) String() string { var v string diff --git a/core/internal/graph/utils.go b/core/internal/graph/utils.go index 50470266..6b3f78cc 100644 --- a/core/internal/graph/utils.go +++ b/core/internal/graph/utils.go @@ -13,6 +13,7 @@ type FPInfo struct { Name string } +// FastParse parses the query and returns the operation type and name func FastParse(gql string) (h FPInfo, err error) { if gql == "" { return h, errors.New("query missing or empty") @@ -20,6 +21,7 @@ func FastParse(gql string) (h FPInfo, err error) { return fastParse(strings.NewReader(gql)) } +// FastParseBytes parses the query and returns the operation type and name func FastParseBytes(gql []byte) (h FPInfo, err error) { if len(gql) == 0 { return h, errors.New("query missing or empty") @@ -27,6 +29,7 @@ func FastParseBytes(gql []byte) (h FPInfo, err error) { return fastParse(bytes.NewReader(gql)) } +// fastParse parses the query and returns the operation type and name func fastParse(r io.Reader) (h FPInfo, err error) { var s scanner.Scanner s.Init(r) diff --git a/core/internal/sdata/dwg.go b/core/internal/sdata/dwg.go index 9aa208e8..feb20fb6 100644 --- a/core/internal/sdata/dwg.go +++ b/core/internal/sdata/dwg.go @@ -14,6 +14,7 @@ var ( ErrThoughNodeNotFound = errors.New("though node not found") ) +// TEdge represents a table edge for the graph type TEdge struct { From, To, Weight int32 @@ -24,32 +25,36 @@ type TEdge struct { name string } +// addNode adds a table node to the graph func (s *DBSchema) addNode(t DBTable) int32 { s.tables = append(s.tables, t) - n := s.rg.AddNode() + n := s.relationshipGraph.AddNode() s.tindex[(t.Schema + ":" + t.Name)] = nodeInfo{n} return n } +// addAliases adds table aliases to the graph func (s *DBSchema) addAliases(t DBTable, nodeID int32, aliases []string) { for _, al := range aliases { s.tindex[(t.Schema + ":" + al)] = nodeInfo{nodeID} - s.ai[al] = nodeInfo{nodeID} + s.tableAliasIndex[al] = nodeInfo{nodeID} } } +// GetAliases returns a map of table aliases func (s *DBSchema) GetAliases() map[string]DBTable { ts := make(map[string]DBTable) - for name, n := range s.ai { + for name, n := range s.tableAliasIndex { ts[name] = s.tables[int(n.nodeID)] } return ts } +// IsAlias checks if a table is an alias func (s *DBSchema) IsAlias(name string) bool { - _, ok := s.ai[name] + _, ok := s.tableAliasIndex[name] return ok } @@ -145,11 +150,11 @@ func (s *DBSchema) addToGraph( return err } - if err := s.rg.UpdateEdge(ln, rn, edgeID1, edgeID2); err != nil { + if err := s.relationshipGraph.UpdateEdge(ln, rn, edgeID1, edgeID2); err != nil { return err } - if err := s.rg.UpdateEdge(rn, ln, edgeID2, edgeID1); err != nil { + if err := s.relationshipGraph.UpdateEdge(rn, ln, edgeID2, edgeID1); err != nil { return err } @@ -166,10 +171,11 @@ func (s *DBSchema) addToGraph( return nil } +// addEdge creates a relationship between two tables func (s *DBSchema) addEdge(name string, edge TEdge, inSchema bool, ) (int32, error) { // add edge to graph - edgeID, err := s.rg.AddEdge(edge.From, edge.To, + edgeID, err := s.relationshipGraph.AddEdge(edge.From, edge.To, edge.Weight, edge.CName) if err != nil { return -1, err @@ -181,13 +187,14 @@ func (s *DBSchema) addEdge(name string, edge TEdge, inSchema bool, if inSchema { edge.name = name } - s.ae[edgeID] = edge + s.allEdges[edgeID] = edge return edgeID, nil } +// addEdgeInfo adds edge info to the index func (s *DBSchema) addEdgeInfo(k string, ei edgeInfo) { - if eiList, ok := s.ei[k]; ok { + if eiList, ok := s.edgesIndex[k]; ok { for i, v := range eiList { if v.nodeID != ei.nodeID { continue @@ -198,13 +205,14 @@ func (s *DBSchema) addEdgeInfo(k string, ei edgeInfo) { } } edgeIDs := append(v.edgeIDs, ei.edgeIDs[0]) - s.ei[k][i].edgeIDs = edgeIDs + s.edgesIndex[k][i].edgeIDs = edgeIDs return } } - s.ei[k] = append(s.ei[k], ei) + s.edgesIndex[k] = append(s.edgesIndex[k], ei) } +// Find returns a table by schema and name func (s *DBSchema) Find(schema, name string) (DBTable, error) { var t DBTable @@ -220,6 +228,7 @@ func (s *DBSchema) Find(schema, name string) (DBTable, error) { return s.tables[v.nodeID], nil } +// TPath represents a table path type TPath struct { Rel RelType LT DBTable @@ -228,13 +237,14 @@ type TPath struct { RC DBColumn } +// FindPath returns a path between two tables func (s *DBSchema) FindPath(from, to, through string) ([]TPath, error) { - fl, ok := s.ei[from] + fl, ok := s.edgesIndex[from] if !ok { return nil, ErrFromEdgeNotFound } - tl, ok := s.ei[to] + tl, ok := s.edgesIndex[to] if !ok { return nil, ErrToEdgeNotFound } @@ -250,7 +260,7 @@ func (s *DBSchema) FindPath(from, to, through string) ([]TPath, error) { path := []TPath{} for _, eid := range res.edges { - edge := s.ae[eid] + edge := s.allEdges[eid] path = append(path, TPath{ Rel: edge.Type, LT: edge.LT, @@ -265,11 +275,13 @@ func (s *DBSchema) FindPath(from, to, through string) ([]TPath, error) { return path, nil } +// graphResult represents a graph result type graphResult struct { from, to edgeInfo edges []int32 } +// between finds a path between two tables func (s *DBSchema) between(from, to []edgeInfo, through string) (res graphResult, err error) { // TODO: picking a path // 1. first look for a direct edge to other table @@ -288,13 +300,14 @@ func (s *DBSchema) between(from, to []edgeInfo, through string) (res graphResult return res, ErrPathNotFound } +// pickPath picks a path between two tables func (s *DBSchema) pickPath(from, to edgeInfo, through string) (res graphResult, err error) { res.from = from res.to = to fn := from.nodeID tn := to.nodeID - paths := s.rg.AllPaths(fn, tn) + paths := s.relationshipGraph.AllPaths(fn, tn) if through != "" { paths, err = s.pickThroughPath(paths, through) @@ -313,6 +326,7 @@ func (s *DBSchema) pickPath(from, to edgeInfo, through string) (res graphResult, return res, ErrPathNotFound } +// pickEdges picks edges between two tables func (s *DBSchema) pickEdges(path []int32, from, to edgeInfo) (edges []int32, allFound bool) { pathLen := len(path) peID := int32(-2) // must be -2 so does not match default -1 @@ -320,7 +334,7 @@ func (s *DBSchema) pickEdges(path []int32, from, to edgeInfo) (edges []int32, al for i := 1; i < pathLen; i++ { fn := path[i-1] tn := path[i] - lines := s.rg.GetEdges(fn, tn) + lines := s.relationshipGraph.GetEdges(fn, tn) // s.PrintLines(lines) @@ -354,6 +368,7 @@ func (s *DBSchema) pickEdges(path []int32, from, to edgeInfo) (edges []int32, al return } +// pickThroughPath picks a path through a node func (s *DBSchema) pickThroughPath(paths [][]int32, through string) ([][]int32, error) { var npaths [][]int32 @@ -376,6 +391,7 @@ func (s *DBSchema) pickThroughPath(paths [][]int32, through string) ([][]int32, return npaths, nil } +// pickLine picks a line between two tables func pickLine(lines []util.Edge, ei edgeInfo, peID int32) *util.Edge { for _, v := range lines { for _, eid := range ei.edgeIDs { @@ -387,6 +403,7 @@ func pickLine(lines []util.Edge, ei edgeInfo, peID int32) *util.Edge { return nil } +// PathToRel converts a table path to a relationship func PathToRel(p TPath) DBRel { return DBRel{ Type: p.Rel, @@ -395,6 +412,7 @@ func PathToRel(p TPath) DBRel { } } +// minWeightedLine returns the line with the minimum weight func minWeightedLine(lines []util.Edge, peID int32) *util.Edge { var min int32 = 100 var line *util.Edge @@ -413,9 +431,10 @@ func minWeightedLine(lines []util.Edge, peID int32) *util.Edge { return line } +// PrintLines prints the graph lines func (s *DBSchema) PrintLines(lines []util.Edge) { for _, v := range lines { - e := s.ae[v.ID] + e := s.allEdges[v.ID] f := s.tables[e.From] t := s.tables[e.To] @@ -425,6 +444,7 @@ func (s *DBSchema) PrintLines(lines []util.Edge) { fmt.Println("---") } +// PrintEdgeInfo prints edge info func (s *DBSchema) PrintEdgeInfo(e edgeInfo) { t := s.tables[e.nodeID] fmt.Printf("-- EdgeInfo %s %+v\n", t.Name, e.edgeIDs) @@ -434,6 +454,7 @@ func (s *DBSchema) PrintEdgeInfo(e edgeInfo) { // } } +// String returns a string representation of a table path func (tp *TPath) String() string { return fmt.Sprintf("(%s) %s ==> %s ==> (%s) %s", tp.LT.String(), tp.LC.String(), diff --git a/core/internal/sdata/schema.go b/core/internal/sdata/schema.go index 6e14bbf1..31c9fac2 100644 --- a/core/internal/sdata/schema.go +++ b/core/internal/sdata/schema.go @@ -19,18 +19,18 @@ type nodeInfo struct { } type DBSchema struct { - typ string // db type - ver int // db version - schema string // db schema - name string // db name - tables []DBTable // tables - vt map[string]VirtualTable // for polymorphic relationships - fm map[string]DBFunction // db functions - tindex map[string]nodeInfo // table index - ai map[string]nodeInfo // table alias index - ei map[string][]edgeInfo // edges index - ae map[int32]TEdge // all edges - rg *util.Graph // relationship graph + dbType string // db type + version int // db version + schema string // db schema + name string // db name + tables []DBTable // tables + virtualTables map[string]VirtualTable // for polymorphic relationships + dbFunctions map[string]DBFunction // db functions + tindex map[string]nodeInfo // table index + tableAliasIndex map[string]nodeInfo // table alias index + edgesIndex map[string][]edgeInfo // edges index + allEdges map[int32]TEdge // all edges + relationshipGraph *util.Graph // relationship graph } type RelType int @@ -46,39 +46,43 @@ const ( RelSkip ) +// DBRelLeft represents database information type DBRelLeft struct { Ti DBTable Col DBColumn } +// DBRelRight represents a database relationship type DBRelRight struct { VTable string Ti DBTable Col DBColumn } +// DBRel represents a database relationship type DBRel struct { Type RelType Left DBRelLeft Right DBRelRight } +// NewDBSchema creates a new database schema func NewDBSchema( info *DBInfo, aliases map[string][]string, ) (*DBSchema, error) { schema := &DBSchema{ - typ: info.Type, - ver: info.Version, - schema: info.Schema, - name: info.Name, - vt: make(map[string]VirtualTable), - fm: make(map[string]DBFunction), - tindex: make(map[string]nodeInfo), - ai: make(map[string]nodeInfo), - ei: make(map[string][]edgeInfo), - ae: make(map[int32]TEdge), - rg: util.NewGraph(), + dbType: info.Type, + version: info.Version, + schema: info.Schema, + name: info.Name, + virtualTables: make(map[string]VirtualTable), + dbFunctions: make(map[string]DBFunction), + tindex: make(map[string]nodeInfo), + tableAliasIndex: make(map[string]nodeInfo), + edgesIndex: make(map[string][]edgeInfo), + allEdges: make(map[int32]TEdge), + relationshipGraph: util.NewGraph(), } for _, t := range info.Tables { @@ -102,11 +106,11 @@ func NewDBSchema( // add aliases to edge index by duplicating for t, al := range aliases { for _, alias := range al { - if _, ok := schema.ei[alias]; ok { + if _, ok := schema.edgesIndex[alias]; ok { continue } - if e, ok := schema.ei[t]; ok { - schema.ei[alias] = e + if e, ok := schema.edgesIndex[t]; ok { + schema.edgesIndex[alias] = e } } } @@ -127,13 +131,14 @@ func NewDBSchema( // don't include functions that return records // as those are considered selector functions if f.Type != "record" { - schema.fm[f.Name] = info.Functions[k] + schema.dbFunctions[f.Name] = info.Functions[k] } } return schema, nil } +// addRels adds relationships to the schema func (s *DBSchema) addRels(t DBTable) error { var err error switch t.Type { @@ -152,6 +157,7 @@ func (s *DBSchema) addRels(t DBTable) error { return s.addColumnRels(t) } +// addJsonRel adds a json relationship to the schema func (s *DBSchema) addJsonRel(t DBTable) error { st, err := s.Find(t.SecondaryCol.Schema, t.SecondaryCol.Table) if err != nil { @@ -166,6 +172,7 @@ func (s *DBSchema) addJsonRel(t DBTable) error { return s.addToGraph(t, t.PrimaryCol, st, sc, RelEmbedded) } +// addPolymorphicRel adds a polymorphic relationship to the schema func (s *DBSchema) addPolymorphicRel(t DBTable) error { pt, err := s.Find(t.PrimaryCol.FKeySchema, t.PrimaryCol.FKeyTable) if err != nil { @@ -185,6 +192,7 @@ func (s *DBSchema) addPolymorphicRel(t DBTable) error { return s.addToGraph(t, t.PrimaryCol, pt, pc, RelPolymorphic) } +// addRemoteRel adds a remote relationship to the schema func (s *DBSchema) addRemoteRel(t DBTable) error { pt, err := s.Find(t.PrimaryCol.FKeySchema, t.PrimaryCol.FKeyTable) if err != nil { @@ -199,6 +207,7 @@ func (s *DBSchema) addRemoteRel(t DBTable) error { return s.addToGraph(t, t.PrimaryCol, pt, pc, RelRemote) } +// addColumnRels adds column relationships to the schema func (s *DBSchema) addColumnRels(t DBTable) error { var err error @@ -244,8 +253,9 @@ func (s *DBSchema) addColumnRels(t DBTable) error { return nil } +// addVirtual adds a virtual table to the schema func (s *DBSchema) addVirtual(vt VirtualTable) error { - s.vt[vt.Name] = vt + s.virtualTables[vt.Name] = vt for _, t := range s.tables { idCol, ok := t.getColumn(vt.IDColumn) @@ -298,22 +308,25 @@ func (s *DBSchema) addVirtual(vt VirtualTable) error { return nil } +// GetTables returns a table from the schema func (s *DBSchema) GetTables() []DBTable { return s.tables } +// RelNode represents a relationship node type RelNode struct { Name string Type RelType Table DBTable } +// GetFirstDegree returns the first degree relationships of a table func (s *DBSchema) GetFirstDegree(t DBTable) (items []RelNode, err error) { currNode, ok := s.tindex[(t.Schema + ":" + t.Name)] if !ok { return nil, fmt.Errorf("table not found: %s", t.String()) } - relatedNodes := s.rg.Connections(currNode.nodeID) + relatedNodes := s.relationshipGraph.Connections(currNode.nodeID) for _, id := range relatedNodes { v := s.getRelNodes(id, currNode.nodeID) items = append(items, v...) @@ -321,15 +334,16 @@ func (s *DBSchema) GetFirstDegree(t DBTable) (items []RelNode, err error) { return } +// GetSecondDegree returns the second degree relationships of a table func (s *DBSchema) GetSecondDegree(t DBTable) (items []RelNode, err error) { currNode, ok := s.tindex[(t.Schema + ":" + t.Name)] if !ok { return nil, fmt.Errorf("table not found: %s", t.String()) } - relatedNodes1 := s.rg.Connections(currNode.nodeID) + relatedNodes1 := s.relationshipGraph.Connections(currNode.nodeID) for _, id := range relatedNodes1 { - relatedNodes2 := s.rg.Connections(id) + relatedNodes2 := s.relationshipGraph.Connections(id) for _, id1 := range relatedNodes2 { v := s.getRelNodes(id1, id) items = append(items, v...) @@ -338,10 +352,11 @@ func (s *DBSchema) GetSecondDegree(t DBTable) (items []RelNode, err error) { return } +// getRelNodes returns the relationship nodes func (s *DBSchema) getRelNodes(fromID, toID int32) (items []RelNode) { - edges := s.rg.GetEdges(fromID, toID) + edges := s.relationshipGraph.GetEdges(fromID, toID) for _, e := range edges { - e1 := s.ae[e.ID] + e1 := s.allEdges[e.ID] if e1.name == "" { continue } @@ -351,6 +366,7 @@ func (s *DBSchema) getRelNodes(fromID, toID int32) (items []RelNode) { return } +// getColumn returns a column from a table func (ti *DBTable) getColumn(name string) (DBColumn, bool) { var c DBColumn if i, ok := ti.colMap[name]; ok { @@ -359,6 +375,7 @@ func (ti *DBTable) getColumn(name string) (DBColumn, bool) { return c, false } +// GetColumn returns a column from a table func (ti *DBTable) GetColumn(name string) (DBColumn, error) { c, ok := ti.getColumn(name) if ok { @@ -367,14 +384,17 @@ func (ti *DBTable) GetColumn(name string) (DBColumn, error) { return c, fmt.Errorf("column: '%s.%s' not found", ti.Name, name) } +// ColumnExists returns true if a column exists in a table func (ti *DBTable) ColumnExists(name string) (DBColumn, bool) { return ti.getColumn(name) } +// GetFunction returns a function from the schema func (s *DBSchema) GetFunctions() map[string]DBFunction { - return s.fm + return s.dbFunctions } +// GetRelName returns the relationship name func GetRelName(colName string) string { cn := colName @@ -397,18 +417,22 @@ func GetRelName(colName string) string { return cn } +// DBType returns the database type func (s *DBSchema) DBType() string { - return s.typ + return s.dbType } +// DBVersion returns the database version func (s *DBSchema) DBVersion() int { - return s.ver + return s.version } +// DBSchema returns the database schema func (s *DBSchema) DBSchema() string { return s.schema } +// DBName returns the database name func (s *DBSchema) DBName() string { return s.name } diff --git a/core/internal/sdata/strings.go b/core/internal/sdata/strings.go index 956ac112..5c45a848 100644 --- a/core/internal/sdata/strings.go +++ b/core/internal/sdata/strings.go @@ -5,10 +5,12 @@ import ( "strings" ) +// String returns a string representation of the DBTable func (ti *DBTable) String() string { return ti.Schema + "." + ti.Name } +// String returns a string representation of the DBColumn func (col DBColumn) String() string { var sb strings.Builder @@ -23,6 +25,7 @@ func (col DBColumn) String() string { return sb.String() } +// String returns a string representation of the DBFunction func (fn DBFunction) String() string { var sb strings.Builder @@ -50,6 +53,7 @@ func (fn DBFunction) String() string { return sb.String() } +// String returns a string representation of the DBRel func (re *DBRel) String() string { return fmt.Sprintf("'%s' --(%s)--> '%s'", re.Left.Col.String(), diff --git a/core/internal/sdata/tables.go b/core/internal/sdata/tables.go index 8be92cb9..cf397b9f 100644 --- a/core/internal/sdata/tables.go +++ b/core/internal/sdata/tables.go @@ -10,6 +10,7 @@ import ( "golang.org/x/sync/errgroup" ) +// DBInfo holds the database schema information type DBInfo struct { Type string Version int @@ -24,6 +25,7 @@ type DBInfo struct { hash int } +// DBTable holds the database table information type DBTable struct { Comment string Schema string @@ -38,6 +40,7 @@ type DBTable struct { colMap map[string]int } +// VirtualTable holds the virtual table information type VirtualTable struct { Name string IDColumn string @@ -45,6 +48,7 @@ type VirtualTable struct { FKeyColumn string } +// GetDBInfo returns the database schema information func GetDBInfo( db *sql.DB, dbType string, @@ -101,6 +105,7 @@ func GetDBInfo( return di, nil } +// NewDBInfo returns a new DBInfo object func NewDBInfo( dbType string, dbVersion int, @@ -176,6 +181,7 @@ func NewDBInfo( return di } +// NewDBTable returns a new DBTable object func NewDBTable(schema, name, _type string, cols []DBColumn) DBTable { ti := DBTable{ Schema: schema, @@ -202,6 +208,7 @@ func NewDBTable(schema, name, _type string, cols []DBColumn) DBTable { return ti } +// AddTable adds a table to the DBInfo object func (di *DBInfo) AddTable(t DBTable) { for i, c := range t.Columns { di.colMap[(c.Schema + ":" + c.Table + ":" + c.Name)] = i @@ -212,6 +219,7 @@ func (di *DBInfo) AddTable(t DBTable) { di.tableMap[(t.Schema + ":" + t.Name)] = i } +// GetTable returns a table from the DBInfo object func (di *DBInfo) GetColumn(schema, table, column string) (*DBColumn, error) { t, err := di.GetTable(schema, table) if err != nil { @@ -226,6 +234,7 @@ func (di *DBInfo) GetColumn(schema, table, column string) (*DBColumn, error) { return &t.Columns[cid], nil } +// GetTable returns a table from the DBInfo object func (di *DBInfo) GetTable(schema, table string) (*DBTable, error) { tid, ok := di.tableMap[(schema + ":" + table)] if !ok { @@ -235,6 +244,7 @@ func (di *DBInfo) GetTable(schema, table string) (*DBTable, error) { return &di.Tables[tid], nil } +// DBColumn returns the column as a string type DBColumn struct { Comment string ID int32 @@ -254,6 +264,7 @@ type DBColumn struct { Schema string } +// DiscoverColumns returns the columns of a table func DiscoverColumns(db *sql.DB, dbtype string, blockList []string) ([]DBColumn, error) { var sqlStmt string @@ -350,6 +361,7 @@ func DiscoverColumns(db *sql.DB, dbtype string, blockList []string) ([]DBColumn, return cols, nil } +// DBFunction holds the database function information type DBFunction struct { Comment string Schema string @@ -360,6 +372,7 @@ type DBFunction struct { Outputs []DBFuncParam } +// DBFuncParam holds the database function parameter information type DBFuncParam struct { ID int Name string @@ -367,6 +380,7 @@ type DBFuncParam struct { Array bool } +// DiscoverFunctions returns the functions of a database func DiscoverFunctions(db *sql.DB, dbtype string, blockList []string) ([]DBFunction, error) { var sqlStmt string @@ -423,6 +437,7 @@ func DiscoverFunctions(db *sql.DB, dbtype string, blockList []string) ([]DBFunct return funcs, nil } +// GetInput returns the input of a function func (fn *DBFunction) GetInput(name string) (ret DBFuncParam, err error) { for _, in := range fn.Inputs { if in.Name == name { @@ -432,10 +447,12 @@ func (fn *DBFunction) GetInput(name string) (ret DBFuncParam, err error) { return ret, fmt.Errorf("function input '%s' not found", name) } +// Hash returns the hash of the DBInfo object func (di *DBInfo) Hash() int { return di.hash } +// isInList checks if a value is in a list func isInList(val string, s []string) bool { for _, v := range s { regex := fmt.Sprintf("^%s$", v) diff --git a/core/internal/util/graph.go b/core/internal/util/graph.go index 67a9e7c7..b2a9a0c1 100644 --- a/core/internal/util/graph.go +++ b/core/internal/util/graph.go @@ -17,16 +17,19 @@ type Graph struct { graph [][]int32 } +// Create a new graph func NewGraph() *Graph { return &Graph{edges: make(map[[2]int32][]Edge)} } +// AddNode adds a new node to the graph func (g *Graph) AddNode() int32 { id := int32(len(g.graph)) g.graph = append(g.graph, []int32{}) return id } +// AddEdge adds a new edge to the graph func (g *Graph) AddEdge(from, to, weight int32, name string) (int32, error) { nl := int32(len(g.graph)) if from >= nl { @@ -55,6 +58,7 @@ func (g *Graph) AddEdge(from, to, weight int32, name string) (int32, error) { return id, nil } +// UpdateEdge updates the edge with the given ID func (g *Graph) UpdateEdge( from, to, edgeID, oppEdgeID int32, ) error { @@ -74,10 +78,12 @@ func (g *Graph) UpdateEdge( return fmt.Errorf("edge not found: %d", edgeID) } +// GetEdges returns all edges between the two nodes func (g *Graph) GetEdges(from, to int32) []Edge { return g.edges[[2]int32{from, to}] } +// AllPaths returns all paths between two nodes func (g *Graph) AllPaths(from, to int32) [][]int32 { var paths [][]int32 var limit int @@ -135,10 +141,12 @@ func (g *Graph) AllPaths(from, to int32) [][]int32 { return paths } +// Connections returns all connections for a given node func (g *Graph) Connections(n int32) []int32 { return g.graph[n] } +// equals checks if two slices are equal func equals(a, b []int32) bool { if len(a) != len(b) { return false diff --git a/core/internal/util/graph_test.go b/core/internal/util/graph_test.go index cbb8b417..caf1914d 100644 --- a/core/internal/util/graph_test.go +++ b/core/internal/util/graph_test.go @@ -53,7 +53,7 @@ func TestGraph1(t *testing.T) { }) edges := g.GetEdges(b, b) - assert.Equals(t, edges, []util.Edge{{13, 2, "test"}}) + assert.Equals(t, edges, []util.Edge{{ID: 13, OppID: 2, Weight: 0, Name: "test"}}) } /* diff --git a/core/intro.go b/core/intro.go index 713fb070..23a2a565 100644 --- a/core/intro.go +++ b/core/intro.go @@ -189,7 +189,7 @@ var stdTypes = []FullType{ }, } -type Introspection struct { +type introspection struct { schema *sdata.DBSchema camelCase bool types map[string]FullType @@ -198,10 +198,11 @@ type Introspection struct { result IntroResult } -func (gj *graphjin) introQuery() (result json.RawMessage, err error) { +// introQuery returns the introspection query result +func (gj *GraphjinEngine) introQuery() (result json.RawMessage, err error) { // Initialize the introscpection object - in := Introspection{ + in := introspection{ schema: gj.schema, camelCase: gj.conf.EnableCamelcase, types: make(map[string]FullType), @@ -216,6 +217,7 @@ func (gj *graphjin) introQuery() (result json.RawMessage, err error) { MutationType: &ShortFullType{Name: "Mutation"}, } + // Add the standard types // Add the standard types for _, v := range stdTypes { in.addType(v) @@ -228,6 +230,11 @@ func (gj *graphjin) introQuery() (result json.RawMessage, err error) { in.addExpTypes(v, "Int", newTypeRef("", "Int", nil)) in.addExpTypes(v, "Boolean", newTypeRef("", "Boolean", nil)) in.addExpTypes(v, "Float", newTypeRef("", "Float", nil)) + in.addExpTypes(v, "ID", newTypeRef("", "ID", nil)) + in.addExpTypes(v, "String", newTypeRef("", "String", nil)) + in.addExpTypes(v, "Int", newTypeRef("", "Int", nil)) + in.addExpTypes(v, "Boolean", newTypeRef("", "Boolean", nil)) + in.addExpTypes(v, "Float", newTypeRef("", "Float", nil)) // ListExpression Types v = append(expAll, expList...) @@ -235,14 +242,21 @@ func (gj *graphjin) introQuery() (result json.RawMessage, err error) { in.addExpTypes(v, "IntList", newTypeRef("", "Int", nil)) in.addExpTypes(v, "BooleanList", newTypeRef("", "Boolean", nil)) in.addExpTypes(v, "FloatList", newTypeRef("", "Float", nil)) + in.addExpTypes(v, "StringList", newTypeRef("", "String", nil)) + in.addExpTypes(v, "IntList", newTypeRef("", "Int", nil)) + in.addExpTypes(v, "BooleanList", newTypeRef("", "Boolean", nil)) + in.addExpTypes(v, "FloatList", newTypeRef("", "Float", nil)) v = append(expAll, expJSON...) in.addExpTypes(v, "JSON", newTypeRef("", "String", nil)) + in.addExpTypes(v, "JSON", newTypeRef("", "String", nil)) + // Add the roles // Add the roles in.addRolesEnumType(gj.roles) in.addTablesEnumType() + // Get all the alias and add to the schema // Get all the alias and add to the schema for alias, t := range in.schema.GetAliases() { if err = in.addTable(t, alias); err != nil { @@ -250,6 +264,7 @@ func (gj *graphjin) introQuery() (result json.RawMessage, err error) { } } + // Get all the tables and add to the schema // Get all the tables and add to the schema for _, t := range in.schema.GetTables() { if err = in.addTable(t, ""); err != nil { @@ -257,12 +272,14 @@ func (gj *graphjin) introQuery() (result json.RawMessage, err error) { } } + // Add the directives // Add the directives for _, dt := range dirTypes { in.addDirType(dt) } in.addDirValidateType() + // Add the types to the schema // Add the types to the schema for _, v := range in.types { in.result.Schema.Types = append(in.result.Schema.Types, v) @@ -272,7 +289,8 @@ func (gj *graphjin) introQuery() (result json.RawMessage, err error) { return } -func (in *Introspection) addTable(table sdata.DBTable, alias string) (err error) { +// addTable adds a table to the introspection schema +func (in *introspection) addTable(table sdata.DBTable, alias string) (err error) { if table.Blocked || len(table.Columns) == 0 { return } @@ -309,7 +327,8 @@ func (in *Introspection) addTable(table sdata.DBTable, alias string) (err error) return } -func (in *Introspection) addTypeTo(op string, ft FullType) { +// addTypeTo adds a type to the introspection schema +func (in *introspection) addTypeTo(op string, ft FullType) { qt := in.types[op] qt.Fields = append(qt.Fields, FieldObject{ Name: ft.Name, @@ -320,7 +339,8 @@ func (in *Introspection) addTypeTo(op string, ft FullType) { in.types[op] = qt } -func (in *Introspection) getName(name string) string { +// getName returns the name of the type +func (in *introspection) getName(name string) string { if in.camelCase { return util.ToCamel(name) } else { @@ -328,7 +348,8 @@ func (in *Introspection) getName(name string) string { } } -func (in *Introspection) addExpTypes(exps []exp, name string, rt *TypeRef) { +// addExpTypes adds the expression types to the introspection schema +func (in *introspection) addExpTypes(exps []exp, name string, rt *TypeRef) { ft := FullType{ Kind: KIND_INPUT_OBJ, Name: (name + SUFFIX_EXP), @@ -350,11 +371,13 @@ func (in *Introspection) addExpTypes(exps []exp, name string, rt *TypeRef) { in.addType(ft) } -func (in *Introspection) addTableType(t sdata.DBTable, alias string) (ft FullType, err error) { +// addTableType adds a table type to the introspection schema +func (in *introspection) addTableType(t sdata.DBTable, alias string) (ft FullType, err error) { return in.addTableTypeWithDepth(t, alias, 0) } -func (in *Introspection) addTableTypeWithDepth( +// addTableTypeWithDepth adds a table type with depth to the introspection schema +func (in *introspection) addTableTypeWithDepth( table sdata.DBTable, alias string, depth int, ) (ft FullType, err error) { ft = FullType{ @@ -463,7 +486,8 @@ func (in *Introspection) addTableTypeWithDepth( return } -func (in *Introspection) addColumnsEnumType(t sdata.DBTable) (err error) { +// addColumnsEnumType adds an enum type for the columns of the table +func (in *introspection) addColumnsEnumType(t sdata.DBTable) (err error) { tableName := in.getName(t.Name) ft := FullType{ Kind: KIND_ENUM, @@ -483,7 +507,8 @@ func (in *Introspection) addColumnsEnumType(t sdata.DBTable) (err error) { return } -func (in *Introspection) addTablesEnumType() { +// addTablesEnumType adds an enum type for the tables +func (in *introspection) addTablesEnumType() { ft := FullType{ Kind: KIND_ENUM, Name: ("tables" + SUFFIX_ENUM), @@ -501,7 +526,8 @@ func (in *Introspection) addTablesEnumType() { in.addType(ft) } -func (in *Introspection) addRolesEnumType(roles map[string]*Role) { +// addRolesEnumType adds an enum type for the roles +func (in *introspection) addRolesEnumType(roles map[string]*Role) { ft := FullType{ Kind: KIND_ENUM, Name: ("roles" + SUFFIX_ENUM), @@ -520,7 +546,8 @@ func (in *Introspection) addRolesEnumType(roles map[string]*Role) { in.addType(ft) } -func (in *Introspection) addOrderByType(t sdata.DBTable, ft *FullType) { +// addOrderByType adds an order by type to the introspection schema +func (in *introspection) addOrderByType(t sdata.DBTable, ft *FullType) { ty := FullType{ Kind: KIND_INPUT_OBJ, Name: (t.Name + SUFFIX_ORDER_BY), @@ -539,7 +566,8 @@ func (in *Introspection) addOrderByType(t sdata.DBTable, ft *FullType) { ft.addArg("orderBy", newTypeRef("", (t.Name+SUFFIX_ORDER_BY), nil)) } -func (in *Introspection) addWhereType(table sdata.DBTable, ft *FullType) { +// addWhereType adds a where type to the introspection schema +func (in *introspection) addWhereType(table sdata.DBTable, ft *FullType) { tablename := (table.Name + SUFFIX_WHERE) ty := FullType{ Kind: "INPUT_OBJECT", @@ -570,7 +598,7 @@ func (in *Introspection) addWhereType(table sdata.DBTable, ft *FullType) { ft.addArg("where", newTypeRef("", ty.Name, nil)) } -func (in *Introspection) addInputType(table sdata.DBTable, ft FullType) (retFT FullType, err error) { +func (in *introspection) addInputType(table sdata.DBTable, ft FullType) (retFT FullType, err error) { // upsert ty := FullType{ Kind: "INPUT_OBJECT", @@ -664,7 +692,8 @@ func (in *Introspection) addInputType(table sdata.DBTable, ft FullType) (retFT F return } -func (in *Introspection) addTableArgsType(table sdata.DBTable, ft *FullType) { +// addTableArgsType adds the table arguments type to the introspection schema +func (in *introspection) addTableArgsType(table sdata.DBTable, ft *FullType) { if table.Type != "function" { return } @@ -673,7 +702,8 @@ func (in *Introspection) addTableArgsType(table sdata.DBTable, ft *FullType) { ft.addArg("args", newTypeRef("", ty.Name, nil)) } -func (in *Introspection) addArgsType(table sdata.DBTable, fn sdata.DBFunction) (ft FullType) { +// addArgsType adds the arguments type to the introspection schema +func (in *introspection) addArgsType(table sdata.DBTable, fn sdata.DBFunction) (ft FullType) { ft = FullType{ Kind: "INPUT_OBJECT", Name: (table.Name + fn.Name + SUFFIX_ARGS), @@ -705,7 +735,8 @@ func (in *Introspection) addArgsType(table sdata.DBTable, fn sdata.DBFunction) ( return } -func (in *Introspection) getColumnField(column sdata.DBColumn) (field FieldObject, err error) { +// getColumnField returns the field object for the given column +func (in *introspection) getColumnField(column sdata.DBColumn) (field FieldObject, err error) { field.Args = []InputValue{} field.Name = in.getName(column.Name) typeValue := newTypeRef("", "String", nil) @@ -736,7 +767,8 @@ func (in *Introspection) getColumnField(column sdata.DBColumn) (field FieldObjec return } -func (in *Introspection) getFunctionField(t sdata.DBTable, fn sdata.DBFunction) (f FieldObject) { +// getFunctionField returns the field object for the given function +func (in *introspection) getFunctionField(t sdata.DBTable, fn sdata.DBFunction) (f FieldObject) { f.Name = in.getName(fn.Name) f.Args = []InputValue{} ty, list := getType(fn.Type) @@ -761,7 +793,8 @@ func (in *Introspection) getFunctionField(t sdata.DBTable, fn sdata.DBFunction) return } -func (in *Introspection) getTableField(relNode sdata.RelNode) ( +// getTableField returns the field object for the given table +func (in *introspection) getTableField(relNode sdata.RelNode) ( f FieldObject, skip bool, err error, ) { f.Args = []InputValue{} @@ -785,7 +818,8 @@ func (in *Introspection) getTableField(relNode sdata.RelNode) ( return } -func (in *Introspection) addDirType(dt dir) { +// addDirType adds a directive type to the introspection schema +func (in *introspection) addDirType(dt dir) { d := DirectiveType{ Name: dt.name, Description: dt.desc, @@ -805,7 +839,8 @@ func (in *Introspection) addDirType(dt dir) { in.result.Schema.Directives = append(in.result.Schema.Directives, d) } -func (in *Introspection) addDirValidateType() { +// addDirValidateType adds a validate directive type to the introspection schema +func (in *introspection) addDirValidateType() { ft := FullType{ Kind: KIND_ENUM, Name: ("validateFormat" + SUFFIX_ENUM), @@ -848,6 +883,7 @@ func (in *Introspection) addDirValidateType() { in.result.Schema.Directives = append(in.result.Schema.Directives, d) } +// addArg adds an argument to the full type func (ft *FullType) addArg(name string, tr *TypeRef) { ft.InputFields = append(ft.InputFields, InputValue{ Name: name, @@ -855,6 +891,7 @@ func (ft *FullType) addArg(name string, tr *TypeRef) { }) } +// addOrReplaceArg adds or replaces an argument to the full type func (ft *FullType) addOrReplaceArg(name string, tr *TypeRef) { for i, a := range ft.InputFields { if a.Name == name { @@ -868,7 +905,8 @@ func (ft *FullType) addOrReplaceArg(name string, tr *TypeRef) { }) } -func (in *Introspection) addType(ft FullType) { +// addType adds a type to the introspection schema +func (in *introspection) addType(ft FullType) { in.types[ft.Name] = ft } diff --git a/core/osfs.go b/core/osfs.go index a8b576d6..9a8ed1c2 100644 --- a/core/osfs.go +++ b/core/osfs.go @@ -8,17 +8,20 @@ import ( ) type osFS struct { - bp string + basePath string } -func NewOsFS(basePath string) *osFS { return &osFS{bp: basePath} } +// NewOsFS creates a new OSFS instance +func NewOsFS(basePath string) *osFS { return &osFS{basePath: basePath} } +// Get returns the file content func (f *osFS) Get(path string) ([]byte, error) { - return os.ReadFile(filepath.Join(f.bp, path)) + return os.ReadFile(filepath.Join(f.basePath, path)) } +// Put writes the data to the file func (f *osFS) Put(path string, data []byte) (err error) { - path = filepath.Join(f.bp, path) + path = filepath.Join(f.basePath, path) dir := filepath.Dir(path) ok, err := f.exists(dir) @@ -32,12 +35,14 @@ func (f *osFS) Put(path string, data []byte) (err error) { return os.WriteFile(path, data, os.ModePerm) } +// Exists checks if the file exists func (f *osFS) Exists(path string) (ok bool, err error) { - path = filepath.Join(f.bp, path) + path = filepath.Join(f.basePath, path) ok, err = f.exists(path) return } +// Remove deletes the file func (f *osFS) exists(path string) (ok bool, err error) { if _, err = os.Stat(path); err == nil { ok = true diff --git a/core/remote_api.go b/core/remote_api.go index cee72c8c..73ef723f 100644 --- a/core/remote_api.go +++ b/core/remote_api.go @@ -11,7 +11,7 @@ import ( "github.com/dosco/graphjin/core/v3/internal/jsn" ) -// RemoteAPI struct defines a remote API endpoint +// remoteAPI struct defines a remote API endpoint type remoteAPI struct { httpClient *http.Client URL string @@ -26,6 +26,7 @@ type remoteHdrs struct { Value string } +// newRemoteAPI creates a new remote API endpoint func newRemoteAPI(v map[string]interface{}, httpClient *http.Client) (*remoteAPI, error) { ra := remoteAPI{ httpClient: httpClient, @@ -50,6 +51,7 @@ func newRemoteAPI(v map[string]interface{}, httpClient *http.Client) (*remoteAPI return &ra, nil } +// Resolve function resolves a remote API request func (r *remoteAPI) Resolve(c context.Context, rr ResolverReq) ([]byte, error) { uri := strings.ReplaceAll(r.URL, "$id", rr.ID) diff --git a/core/remote_join.go b/core/remote_join.go index d1abdb23..8013e94c 100644 --- a/core/remote_join.go +++ b/core/remote_join.go @@ -11,6 +11,7 @@ import ( "github.com/dosco/graphjin/core/v3/internal/qcode" ) +// execRemoteJoin fetches remote data for the marked insertion points func (s *gstate) execRemoteJoin(c context.Context) (err error) { // fetch the field name used within the db response json // that are used to mark insertion points and the mapping between @@ -41,6 +42,7 @@ func (s *gstate) execRemoteJoin(c context.Context) (err error) { return } +// resolveRemotes fetches remote data for the marked insertion points func (s *gstate) resolveRemotes( ctx context.Context, from []jsn.Field, @@ -85,7 +87,7 @@ func (s *gstate) resolveRemotes( ctx1, span := s.gj.spanStart(ctx, "Execute Remote Request") b, err := r.Fn.Resolve(ctx1, ResolverReq{ - ID: string(id), Sel: sel, Log: s.gj.log, ReqConfig: s.r.rc, + ID: string(id), Sel: sel, Log: s.gj.log, RequestConfig: s.r.requestconfig, }) if err != nil { cerr = fmt.Errorf("%s: %s", sel.Table, err) @@ -121,6 +123,7 @@ func (s *gstate) resolveRemotes( return to, cerr } +// parentFieldIds fetches the field name used within the db response json func (s *gstate) parentFieldIds() ([][]byte, map[string]*qcode.Select, error) { selects := s.cs.st.qc.Selects remotes := s.cs.st.qc.Remotes @@ -148,6 +151,7 @@ func (s *gstate) parentFieldIds() ([][]byte, map[string]*qcode.Select, error) { return fm, sm, nil } +// fieldsToList converts a list of qcode.Field to a list of strings func fieldsToList(fields []qcode.Field) []string { var f []string diff --git a/core/resolve.go b/core/resolve.go index 7ae844e7..678cc402 100644 --- a/core/resolve.go +++ b/core/resolve.go @@ -15,7 +15,8 @@ type resItem struct { Fn Resolver } -func (gj *graphjin) newRTMap() map[string]ResolverFn { +// newRTMap returns a map of resolver functions +func (gj *GraphjinEngine) newRTMap() map[string]ResolverFn { return map[string]ResolverFn{ "remote_api": func(v ResolverProps) (Resolver, error) { return newRemoteAPI(v, gj.trace.NewHTTPClient()) @@ -23,7 +24,8 @@ func (gj *graphjin) newRTMap() map[string]ResolverFn { } } -func (gj *graphjin) initResolvers() error { +// initResolvers initializes the resolvers +func (gj *GraphjinEngine) initResolvers() error { gj.rmap = make(map[string]resItem) if gj.rtmap == nil { @@ -42,7 +44,8 @@ func (gj *graphjin) initResolvers() error { return nil } -func (gj *graphjin) initRemote( +// initRemote initializes the remote resolver +func (gj *GraphjinEngine) initRemote( rc ResolverConfig, rtmap map[string]ResolverFn, ) error { // Defines the table column to be used as an id in the diff --git a/core/rolestmt.go b/core/rolestmt.go index 522d8b5e..61a59aa2 100644 --- a/core/rolestmt.go +++ b/core/rolestmt.go @@ -8,7 +8,7 @@ import ( ) // nolint:errcheck -func (gj *graphjin) prepareRoleStmt() error { +func (gj *GraphjinEngine) prepareRoleStmt() error { if !gj.abacEnabled { return nil } @@ -20,7 +20,7 @@ func (gj *graphjin) prepareRoleStmt() error { w := &bytes.Buffer{} io.WriteString(w, `SELECT (CASE WHEN EXISTS (`) - gj.pc.RenderVar(w, &gj.roleStmtMD, gj.conf.RolesQuery) + gj.psqlCompiler.RenderVar(w, &gj.roleStatementMetadata, gj.conf.RolesQuery) io.WriteString(w, `) THEN `) io.WriteString(w, `(SELECT (CASE`) @@ -36,7 +36,7 @@ func (gj *graphjin) prepareRoleStmt() error { } io.WriteString(w, ` ELSE 'user' END) FROM (`) - gj.pc.RenderVar(w, &gj.roleStmtMD, gj.conf.RolesQuery) + gj.psqlCompiler.RenderVar(w, &gj.roleStatementMetadata, gj.conf.RolesQuery) io.WriteString(w, `) AS _sg_auth_roles_query LIMIT 1) `) switch gj.dbtype { @@ -47,6 +47,6 @@ func (gj *graphjin) prepareRoleStmt() error { io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS _sg_auth_filler LIMIT 1; `) } - gj.roleStmt = w.String() + gj.roleStatement = w.String() return nil } diff --git a/core/schema.go b/core/schema.go index 938cd728..3ce5521d 100644 --- a/core/schema.go +++ b/core/schema.go @@ -82,6 +82,7 @@ type {{.Name}} {{end -}} ` +// writeSchema writes the schema to the given writer func writeSchema(s *sdata.DBInfo, out io.Writer) (err error) { fn := template.FuncMap{ "pascal": toPascalCase, @@ -104,6 +105,7 @@ func writeSchema(s *sdata.DBInfo, out io.Writer) (err error) { return } +// toPascalCase converts a string to pascal case func toPascalCase(text string) string { var sb strings.Builder for _, v := range strings.Fields(text) { @@ -115,6 +117,7 @@ func toPascalCase(text string) string { var dbTypeRe = regexp.MustCompile(`([a-zA-Z ]+)(\((.+)\))?`) +// parseDBType parses the db type string func parseDBType(name string) (res [2]string, err error) { v := dbTypeRe.FindStringSubmatch(name) if len(v) == 4 { diff --git a/core/subs.go b/core/subs.go index 6ee8520e..ed149a46 100644 --- a/core/subs.go +++ b/core/subs.go @@ -83,7 +83,7 @@ func (g *GraphJin) Subscribe( c context.Context, query string, vars json.RawMessage, - rc *ReqConfig, + rc *RequestConfig, ) (m *Member, err error) { // get the name, query vars h, err := graph.FastParse(query) @@ -91,7 +91,7 @@ func (g *GraphJin) Subscribe( return } - gj := g.Load().(*graphjin) + gj := g.Load().(*GraphjinEngine) // create the request object r := gj.newGraphqlReq(rc, "subscription", h.Name, nil, vars) @@ -119,9 +119,9 @@ func (g *GraphJin) SubscribeByName( c context.Context, name string, vars json.RawMessage, - rc *ReqConfig, + rc *RequestConfig, ) (m *Member, err error) { - gj := g.Load().(*graphjin) + gj := g.Load().(*GraphjinEngine) item, err := gj.allowList.GetByName(name, gj.prod) if err != nil { @@ -134,15 +134,16 @@ func (g *GraphJin) SubscribeByName( return } -func (gj *graphjin) subscribe(c context.Context, r graphqlReq) ( +// subscribe function is called on the graphjin struct to subscribe to a query. +func (gj *GraphjinEngine) subscribe(c context.Context, r GraphqlReq) ( m *Member, err error, ) { - if r.op != qcode.QTSubscription { + if r.operation != qcode.QTSubscription { return nil, errors.New("subscription: not a subscription query") } // transactions not supported with subscriptions - if r.rc != nil && r.rc.Tx != nil { + if r.requestconfig != nil && r.requestconfig.Tx != nil { return nil, errors.New("subscription: database transactions not supported") } @@ -189,7 +190,7 @@ func (gj *graphjin) subscribe(c context.Context, r graphqlReq) ( } m = &Member{ - ns: r.ns, + ns: r.namespace, id: atomic.AddUint64(&sub.idgen, 1), Result: make(chan *Result, 10), sub: sub, @@ -206,13 +207,14 @@ func (gj *graphjin) subscribe(c context.Context, r graphqlReq) ( return } -func (gj *graphjin) initSub(c context.Context, sub *sub) (err error) { +// initSub function is called on the graphjin struct to initialize a subscription. +func (gj *GraphjinEngine) initSub(c context.Context, sub *sub) (err error) { if err = sub.s.compile(); err != nil { return } if !gj.prod { - err = gj.saveToAllowList(sub.s.cs.st.qc, sub.s.r.ns) + err = gj.saveToAllowList(sub.s.cs.st.qc, sub.s.r.namespace) if err != nil { return } @@ -226,7 +228,8 @@ func (gj *graphjin) initSub(c context.Context, sub *sub) (err error) { return } -func (gj *graphjin) subController(sub *sub) { +// subController function is called on the graphjin struct to control the subscription. +func (gj *GraphjinEngine) subController(sub *sub) { // remove subscription if controller exists defer gj.subs.Delete(sub.k) @@ -264,6 +267,7 @@ func (gj *graphjin) subController(sub *sub) { } } +// addMember function is called on the sub struct to add a member. func (s *sub) addMember(m *Member) error { mi := minfo{cindx: m.cindx} if mi.cindx != -1 { @@ -294,6 +298,7 @@ func (s *sub) addMember(m *Member) error { return nil } +// deleteMember function is called on the sub struct to delete a member. func (s *sub) deleteMember(m *Member) { i, ok := s.findByID(m.id) if !ok { @@ -313,6 +318,7 @@ func (s *sub) deleteMember(m *Member) { s.ids = s.ids[:len(s.ids)-1] } +// updateMember function is called on the sub struct to update a member. func (s *sub) updateMember(msg mmsg) error { i, ok := s.findByID(msg.id) if !ok { @@ -338,7 +344,8 @@ func (s *sub) updateMember(msg mmsg) error { return nil } -func (s *sub) fanOutJobs(gj *graphjin) { +// fanOutJobs function is called on the sub struct to fan out jobs. +func (s *sub) fanOutJobs(gj *GraphjinEngine) { switch { case len(s.ids) == 0: return @@ -355,7 +362,8 @@ func (s *sub) fanOutJobs(gj *graphjin) { } } -func (gj *graphjin) subCheckUpdates(sub *sub, mv mval, start int) { +// subCheckUpdates function is called on the graphjin struct to check updates. +func (gj *GraphjinEngine) subCheckUpdates(sub *sub, mv mval, start int) { // Do not use the `mval` embedded inside sub since // its not thread safe use the copy `mv mval`. @@ -433,7 +441,8 @@ func (gj *graphjin) subCheckUpdates(sub *sub, mv mval, start int) { } } -func (gj *graphjin) subFirstQuery(sub *sub, m *Member) (mmsg, error) { +// subFirstQuery function is called on the graphjin struct to get the first query. +func (gj *GraphjinEngine) subFirstQuery(sub *sub, m *Member) (mmsg, error) { c := context.Background() // when params are not available we use a more optimized @@ -473,7 +482,8 @@ func (gj *graphjin) subFirstQuery(sub *sub, m *Member) (mmsg, error) { return mm, err } -func (gj *graphjin) subNotifyMember(s *sub, mv mval, j int, js json.RawMessage) { +// subNotifyMember function is called on the graphjin struct to notify a member. +func (gj *GraphjinEngine) subNotifyMember(s *sub, mv mval, j int, js json.RawMessage) { _, err := gj.subNotifyMemberEx(s, mv.mi[j].dh, mv.mi[j].cindx, @@ -484,7 +494,8 @@ func (gj *graphjin) subNotifyMember(s *sub, mv mval, j int, js json.RawMessage) } } -func (gj *graphjin) subNotifyMemberEx(sub *sub, +// subNotifyMemberEx function is called on the graphjin struct to notify a member. +func (gj *GraphjinEngine) subNotifyMemberEx(sub *sub, dh [32]byte, cindx int, id uint64, rc chan *Result, js json.RawMessage, update bool, ) (mmsg, error) { mm := mmsg{id: id} @@ -496,15 +507,15 @@ func (gj *graphjin) subNotifyMemberEx(sub *sub, nonce := mm.dh - if cv := firstCursorValue(js, gj.pf); len(cv) != 0 { + if cv := firstCursorValue(js, gj.printFormat); len(cv) != 0 { mm.cursor = string(cv) } ejs, err := encryptValues(js, - gj.pf, + gj.printFormat, decPrefix, nonce[:], - gj.encKey) + gj.encryptionKey) if err != nil { return mm, fmt.Errorf(errSubs, "cursor", err) } @@ -520,11 +531,11 @@ func (gj *graphjin) subNotifyMemberEx(sub *sub, } res := &Result{ - op: qcode.QTQuery, - name: sub.s.r.name, - sql: sub.s.cs.st.sql, - role: sub.s.cs.st.role, - Data: ejs, + operation: qcode.QTQuery, + name: sub.s.r.name, + sql: sub.s.cs.st.sql, + role: sub.s.cs.st.role, + Data: ejs, } // if parameters exists then each response is unique @@ -538,6 +549,7 @@ func (gj *graphjin) subNotifyMemberEx(sub *sub, return mm, nil } +// renderSubWrap function is called on the graphjin struct to render a sub wrap. func renderSubWrap(st stmt, ct string) string { var w strings.Builder @@ -577,6 +589,7 @@ func renderSubWrap(st stmt, ct string) string { return w.String() } +// renderJSONArray function is called on the graphjin struct to render a json array. func renderJSONArray(v []json.RawMessage) json.RawMessage { w := bytes.Buffer{} w.WriteRune('[') @@ -590,6 +603,7 @@ func renderJSONArray(v []json.RawMessage) json.RawMessage { return json.RawMessage(w.Bytes()) } +// findByID function is called on the sub struct to find a member by id. func (s *sub) findByID(id uint64) (int, bool) { for i := range s.ids { if s.ids[i] == id { @@ -599,6 +613,7 @@ func (s *sub) findByID(id uint64) (int, bool) { return 0, false } +// Unsubscribe function is called on the member struct to unsubscribe. func (m *Member) Unsubscribe() { if m != nil && !m.done { m.sub.del <- m @@ -606,10 +621,12 @@ func (m *Member) Unsubscribe() { } } +// ID function is called on the member struct to get the id. func (m *Member) ID() uint64 { return m.id } +// String function is called on the member struct to get the string. func (m *Member) String() string { return strconv.Itoa(int(m.id)) } diff --git a/core/trace.go b/core/trace.go index af947ae0..c182beb7 100644 --- a/core/trace.go +++ b/core/trace.go @@ -21,17 +21,21 @@ type tracer struct{} type span struct{} +// Start starts a new trace span func (t *tracer) Start(c context.Context, name string) (context.Context, Spaner) { return c, &span{} } +// NewHTTPClient creates a new HTTP client func (t *tracer) NewHTTPClient() *http.Client { return &http.Client{} } +// End ends the span func (s *span) End() { } +// Error logs an error func (s *span) Error(err error) { } @@ -40,9 +44,11 @@ type StringAttr struct { Value string } +// IsRecording returns true if the span is recording func (s *span) IsRecording() bool { return false } +// SetAttributesString sets the attributes func (s *span) SetAttributesString(attrs ...StringAttr) { } diff --git a/core/watcher.go b/core/watcher.go index 3f25351b..f1bb1478 100644 --- a/core/watcher.go +++ b/core/watcher.go @@ -6,8 +6,9 @@ import ( "github.com/dosco/graphjin/core/v3/internal/sdata" ) +// initDBWatcher initializes the database schema watcher func (g *GraphJin) initDBWatcher() error { - gj := g.Load().(*graphjin) + gj := g.Load().(*GraphjinEngine) // no schema polling in production if gj.prod { @@ -30,12 +31,13 @@ func (g *GraphJin) initDBWatcher() error { return nil } +// startDBWatcher starts the database schema watcher func (g *GraphJin) startDBWatcher(ps time.Duration) { ticker := time.NewTicker(ps) defer ticker.Stop() for range ticker.C { - gj := g.Load().(*graphjin) + gj := g.Load().(*GraphjinEngine) latestDi, err := sdata.GetDBInfo( gj.db, diff --git a/go.work.sum b/go.work.sum index 3c2e02de..c887ea90 100644 --- a/go.work.sum +++ b/go.work.sum @@ -676,6 +676,7 @@ github.com/fatih/color v1.14.1/go.mod h1:2oHN61fhTpgcxD3TSWCgKDiH1+x4OiDVVGH8Wlg github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k= github.com/frankban/quicktest v1.14.0/go.mod h1:NeW+ay9A/U67EYXNFA1nPE8e/tnQv/09mUdL/ijj8og= github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-asn1-ber/asn1-ber v1.5.1/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= github.com/go-asn1-ber/asn1-ber v1.5.5 h1:MNHlNMBDgEKD4TcKr36vQN68BA00aDfjIt3/bD50WnA= github.com/go-asn1-ber/asn1-ber v1.5.5/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= diff --git a/serv/admin.go b/serv/admin.go index 89847e26..bc41437a 100644 --- a/serv/admin.go +++ b/serv/admin.go @@ -12,10 +12,11 @@ import ( "time" ) -func adminDeployHandler(s1 *Service) http.Handler { +// adminDeployHandler handles the admin deploy endpoint +func adminDeployHandler(s1 *HttpService) http.Handler { h := func(w http.ResponseWriter, r *http.Request) { var req DeployReq - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) if !s.isAdminSecret(r) { authFail(w) @@ -57,9 +58,10 @@ func adminDeployHandler(s1 *Service) http.Handler { return http.HandlerFunc(h) } -func adminRollbackHandler(s1 *Service) http.Handler { +// adminRollbackHandler handles the admin rollback endpoint +func adminRollbackHandler(s1 *HttpService) http.Handler { h := func(w http.ResponseWriter, r *http.Request) { - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) if !s.isAdminSecret(r) { authFail(w) @@ -85,7 +87,8 @@ func adminRollbackHandler(s1 *Service) http.Handler { return http.HandlerFunc(h) } -func (s *service) isAdminSecret(r *http.Request) bool { +// adminConfigHandler handles the checking of the admin secret endpoint +func (s *graphjinService) isAdminSecret(r *http.Request) bool { atomic.AddInt32(&s.adminCount, 1) defer atomic.StoreInt32(&s.adminCount, 0) @@ -105,14 +108,17 @@ func (s *service) isAdminSecret(r *http.Request) bool { return (err == nil) && bytes.Equal(v1, s.asec[:]) } +// badReq sends a bad request response func badReq(w http.ResponseWriter, msg string) { http.Error(w, msg, http.StatusBadRequest) } +// intErr sends an internal server error response func intErr(w http.ResponseWriter, msg string) { http.Error(w, msg, http.StatusInternalServerError) } +// authFail sends an unauthorized response func authFail(w http.ResponseWriter) { http.Error(w, "auth failed", http.StatusUnauthorized) } diff --git a/serv/afero.go b/serv/afero.go index 10276062..428683c7 100644 --- a/serv/afero.go +++ b/serv/afero.go @@ -11,14 +11,17 @@ type aferoFS struct { fs afero.Fs } +// newAferoFS creates a new aferoFS instance func newAferoFS(fs afero.Fs, basePath string) *aferoFS { return &aferoFS{fs: afero.NewBasePathFs(fs, basePath)} } +// Get reads a file from the file system func (f *aferoFS) Get(path string) ([]byte, error) { return afero.ReadFile(f.fs, path) } +// Put writes a file to the file system func (f *aferoFS) Put(path string, data []byte) (err error) { dir := filepath.Dir(path) ok, err := f.Exists(dir) @@ -32,6 +35,7 @@ func (f *aferoFS) Put(path string, data []byte) (err error) { return afero.WriteFile(f.fs, path, data, os.ModePerm) } +// Exists checks if a file exists in the file system func (f *aferoFS) Exists(path string) (exists bool, err error) { return afero.Exists(f.fs, path) } diff --git a/serv/api.go b/serv/api.go index 184ac7b0..075e98cd 100644 --- a/serv/api.go +++ b/serv/api.go @@ -56,7 +56,7 @@ import ( "go.uber.org/zap/zapcore" ) -type Service struct { +type HttpService struct { atomic.Value opt []Option cpath string @@ -71,7 +71,7 @@ const ( type HookFn func(*core.Result) -type service struct { +type graphjinService struct { log *zap.SugaredLogger // logger zlog *zap.Logger // faster logger logLevel int // log level @@ -92,10 +92,10 @@ type service struct { tracer trace.Tracer } -type Option func(*service) error +type Option func(*graphjinService) error // NewGraphJinService a new service -func NewGraphJinService(conf *Config, options ...Option) (*Service, error) { +func NewGraphJinService(conf *Config, options ...Option) (*HttpService, error) { if conf.dirty { return nil, errors.New("do not re-use config object") } @@ -105,7 +105,7 @@ func NewGraphJinService(conf *Config, options ...Option) (*Service, error) { return nil, err } - s1 := &Service{opt: options, cpath: conf.Serv.ConfigPath} + s1 := &HttpService{opt: options, cpath: conf.Serv.ConfigPath} s1.Store(s) if s.conf.WatchAndReload { @@ -121,7 +121,7 @@ func NewGraphJinService(conf *Config, options ...Option) (*Service, error) { // OptionSetDB sets a new db client func OptionSetDB(db *sql.DB) Option { - return func(s *service) error { + return func(s *graphjinService) error { s.db = db return nil } @@ -129,7 +129,7 @@ func OptionSetDB(db *sql.DB) Option { // OptionSetHookFunc sets a function to be called on every request func OptionSetHookFunc(fn HookFn) Option { - return func(s *service) error { + return func(s *graphjinService) error { s.hook = fn return nil } @@ -137,7 +137,7 @@ func OptionSetHookFunc(fn HookFn) Option { // OptionSetNamespace sets service namespace func OptionSetNamespace(namespace string) Option { - return func(s *service) error { + return func(s *graphjinService) error { s.namespace = &namespace return nil } @@ -145,7 +145,7 @@ func OptionSetNamespace(namespace string) Option { // OptionSetFS sets service filesystem func OptionSetFS(fs core.FS) Option { - return func(s *service) error { + return func(s *graphjinService) error { s.fs = fs return nil } @@ -153,7 +153,7 @@ func OptionSetFS(fs core.FS) Option { // OptionSetZapLogger sets service structured logger func OptionSetZapLogger(zlog *zap.Logger) Option { - return func(s *service) error { + return func(s *graphjinService) error { s.zlog = zlog s.log = zlog.Sugar() return nil @@ -162,13 +162,14 @@ func OptionSetZapLogger(zlog *zap.Logger) Option { // OptionDeployActive caused the active config to be deployed on func OptionDeployActive() Option { - return func(s *service) error { + return func(s *graphjinService) error { s.deployActive = true return nil } } -func newGraphJinService(conf *Config, db *sql.DB, options ...Option) (*service, error) { +// newGraphJinService creates a new service +func newGraphJinService(conf *Config, db *sql.DB, options ...Option) (*graphjinService, error) { var err error if conf == nil { conf = &Config{Core: Core{Debug: true}} @@ -178,7 +179,7 @@ func newGraphJinService(conf *Config, db *sql.DB, options ...Option) (*service, prod := conf.Serv.Production conf.Core.Production = prod - s := &service{ + s := &graphjinService{ conf: conf, zlog: zlog, log: zlog.Sugar(), @@ -224,7 +225,8 @@ func newGraphJinService(conf *Config, db *sql.DB, options ...Option) (*service, return s, nil } -func (s *service) normalStart() error { +// normalStart starts the service in normal mode +func (s *graphjinService) normalStart() error { opts := []core.Option{ core.OptionSetFS(s.fs), core.OptionSetTrace(otelPlugin.NewTracerFrom(s.tracer)), @@ -238,7 +240,8 @@ func (s *service) normalStart() error { return err } -func (s *service) hotStart() error { +// hotStart starts the service in hot-deploy mode +func (s *graphjinService) hotStart() error { ab, err := fetchActiveBundle(s.db) if err != nil { if strings.Contains(err.Error(), "_graphjin.") { @@ -251,7 +254,7 @@ func (s *service) hotStart() error { return s.normalStart() } - cf := s.conf.vi.ConfigFileUsed() + cf := s.conf.viper.ConfigFileUsed() cf = filepath.Base(strings.TrimSuffix(cf, filepath.Ext(cf))) cf = filepath.Join("/", cf) @@ -283,9 +286,9 @@ func (s *service) hotStart() error { } // Deploy a new configuration -func (s *Service) Deploy(conf *Config, options ...Option) error { +func (s *HttpService) Deploy(conf *Config, options ...Option) error { var err error - os := s.Load().(*service) + os := s.Load().(*graphjinService) if conf == nil { return nil @@ -304,27 +307,28 @@ func (s *Service) Deploy(conf *Config, options ...Option) error { } // Start the service listening on the configured port -func (s *Service) Start() error { +func (s *HttpService) Start() error { startHTTP(s) return nil } // Attach route to the internal http service -func (s *Service) Attach(mux Mux) error { +func (s *HttpService) Attach(mux Mux) error { return s.attach(mux, nil) } // AttachWithNS a namespaced route to the internal http service -func (s *Service) AttachWithNS(mux Mux, namespace string) error { +func (s *HttpService) AttachWithNS(mux Mux, namespace string) error { return s.attach(mux, &namespace) } -func (s *Service) attach(mux Mux, ns *string) error { +// attach attaches the service to the router +func (s *HttpService) attach(mux Mux, ns *string) error { if _, err := routesHandler(s, mux, ns); err != nil { return err } - s1 := s.Load().(*service) + s1 := s.Load().(*graphjinService) ver := version dep := s1.conf.name @@ -356,26 +360,26 @@ func (s *Service) attach(mux Mux, ns *string) error { } // GraphQLis the http handler the GraphQL endpoint -func (s *Service) GraphQL(ah auth.HandlerFunc) http.Handler { +func (s *HttpService) GraphQL(ah auth.HandlerFunc) http.Handler { return s.apiHandler(nil, ah, false) } // GraphQLWithNS is the http handler the namespaced GraphQL endpoint -func (s *Service) GraphQLWithNS(ah auth.HandlerFunc, ns string) http.Handler { +func (s *HttpService) GraphQLWithNS(ah auth.HandlerFunc, ns string) http.Handler { return s.apiHandler(&ns, ah, false) } // REST is the http handler the REST endpoint -func (s *Service) REST(ah auth.HandlerFunc) http.Handler { +func (s *HttpService) REST(ah auth.HandlerFunc) http.Handler { return s.apiHandler(nil, ah, true) } // RESTWithNS is the http handler the namespaced REST endpoint -func (s *Service) RESTWithNS(ah auth.HandlerFunc, ns string) http.Handler { +func (s *HttpService) RESTWithNS(ah auth.HandlerFunc, ns string) http.Handler { return s.apiHandler(&ns, ah, true) } -func (s *Service) apiHandler(ns *string, ah auth.HandlerFunc, rest bool) http.Handler { +func (s *HttpService) apiHandler(ns *string, ah auth.HandlerFunc, rest bool) http.Handler { var h http.Handler if rest { h = s.apiV1Rest(ns, ah) @@ -386,32 +390,34 @@ func (s *Service) apiHandler(ns *string, ah auth.HandlerFunc, rest bool) http.Ha } // WebUI is the http handler the web ui endpoint -func (s *Service) WebUI(routePrefix, gqlEndpoint string) http.Handler { +func (s *HttpService) WebUI(routePrefix, gqlEndpoint string) http.Handler { return webuiHandler(routePrefix, gqlEndpoint) } // GetGraphJin fetching internal GraphJin core -func (s *Service) GetGraphJin() *core.GraphJin { - s1 := s.Load().(*service) +func (s *HttpService) GetGraphJin() *core.GraphJin { + s1 := s.Load().(*graphjinService) return s1.gj } // GetDB fetching internal db client -func (s *Service) GetDB() *sql.DB { - s1 := s.Load().(*service) +func (s *HttpService) GetDB() *sql.DB { + s1 := s.Load().(*graphjinService) return s1.db } // Reload re-runs database discover and reinitializes service. -func (s *Service) Reload() error { - s1 := s.Load().(*service) +func (s *HttpService) Reload() error { + s1 := s.Load().(*graphjinService) return s1.gj.Reload() } -func (s *service) spanStart(c context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { +// spanStart starts the tracer +func (s *graphjinService) spanStart(c context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { return s.tracer.Start(c, name, opts...) } +// spanError records an error in the span func spanError(span trace.Span, err error) { if span.IsRecording() { span.RecordError(err) diff --git a/serv/client.go b/serv/client.go index 2b8a0422..84ed0fd4 100644 --- a/serv/client.go +++ b/serv/client.go @@ -39,6 +39,7 @@ type Resp struct { Msg string } +// NewAdminClient creates a new admin client func NewAdminClient(host string, secret string) *Client { c := resty.New(). SetBaseURL(host). @@ -67,6 +68,7 @@ func NewAdminClient(host string, secret string) *Client { return &Client{c} } +// Deploy deploys the configuration to the server func (c *Client) Deploy(name, confPath string) (*Resp, error) { errMsg := "deploy failed: %w" @@ -85,6 +87,7 @@ func (c *Client) Deploy(name, confPath string) (*Resp, error) { return &Resp{Msg: string(res.Body())}, nil } +// Rollback rolls back the last deployment func (c *Client) Rollback() (*Resp, error) { errMsg := "rollback failed: %w" @@ -97,6 +100,7 @@ func (c *Client) Rollback() (*Resp, error) { return &Resp{Msg: string(res.Body())}, nil } +// buildBundle creates a zip archive of the configuration directory func buildBundle(confPath string) (string, error) { buf := bytes.Buffer{} z := zip.NewWriter(&buf) diff --git a/serv/config.go b/serv/config.go index 7ac30e19..eac4d580 100644 --- a/serv/config.go +++ b/serv/config.go @@ -37,7 +37,7 @@ type Config struct { hash string name string dirty bool - vi *viper.Viper + viper *viper.Viper } // Configuration for admin service @@ -239,23 +239,24 @@ func ReadInConfig(configFile string) (*Config, error) { // ReadInConfigFS is the same as ReadInConfig but it also takes a filesytem as an argument func ReadInConfigFS(configFile string, fs afero.Fs) (*Config, error) { - c, err := readInConfig(configFile, fs) + config, err := readInConfig(configFile, fs) if err != nil { return nil, err } - c1, err := setupSecrets(c, fs) + secrets, err := setupSecrets(config, fs) if err != nil { - return nil, fmt.Errorf("%w: %s", err, c.SecretsFile) + return nil, fmt.Errorf("%w: %s", err, config.SecretsFile) } - return c1, err + return secrets, err } +// setupSecrets function reads in the secrets file and merges the secrets into the config func setupSecrets(conf *Config, fs afero.Fs) (*Config, error) { if conf.SecretsFile == "" { return conf, nil } - secFile, err := filepath.Abs(conf.RelPath(conf.SecretsFile)) + secFile, err := filepath.Abs(conf.AbsolutePath(conf.SecretsFile)) if err != nil { return nil, err } @@ -267,15 +268,15 @@ func setupSecrets(conf *Config, fs afero.Fs) (*Config, error) { return nil, err } - for k, v := range newConf.secrets { - util.SetKeyValue(conf.vi, k, v) + for secretKey, secretValue := range newConf.secrets { + util.SetKeyValue(conf.viper, secretKey, secretValue) } if len(newConf.secrets) == 0 { return conf, nil } - if err := conf.vi.Unmarshal(&newConf); err != nil { + if err := conf.viper.Unmarshal(&newConf); err != nil { return nil, fmt.Errorf("failed to decode config, %v", err) } @@ -289,35 +290,37 @@ func setupSecrets(conf *Config, fs afero.Fs) (*Config, error) { return &newConf, nil } +// readInConfig function reads in the config file for the environment specified in the GO_ENV func readInConfig(configFile string, fs afero.Fs) (*Config, error) { cp := filepath.Dir(configFile) - vi := newViper(cp, filepath.Base(configFile)) + viper := newViper(cp, filepath.Base(configFile)) if fs != nil { - vi.SetFs(fs) + viper.SetFs(fs) } - if err := vi.ReadInConfig(); err != nil { + + if err := viper.ReadInConfig(); err != nil { return nil, err } - if pcf := vi.GetString("inherits"); pcf != "" { - cf := vi.ConfigFileUsed() - vi = newViper(cp, pcf) + if pcf := viper.GetString("inherits"); pcf != "" { + cf := viper.ConfigFileUsed() + viper = newViper(cp, pcf) if fs != nil { - vi.SetFs(fs) + viper.SetFs(fs) } - if err := vi.ReadInConfig(); err != nil { + if err := viper.ReadInConfig(); err != nil { return nil, err } - if v := vi.GetString("inherits"); v != "" { - return nil, fmt.Errorf("inherited config '%s' cannot itself inherit '%s'", pcf, v) + if value := viper.GetString("inherits"); value != "" { + return nil, fmt.Errorf("inherited config '%s' cannot itself inherit '%s'", pcf, value) } - vi.SetConfigFile(cf) + viper.SetConfigFile(cf) - if err := vi.MergeInConfig(); err != nil { + if err := viper.MergeInConfig(); err != nil { return nil, err } } @@ -325,20 +328,21 @@ func readInConfig(configFile string, fs afero.Fs) (*Config, error) { for _, e := range os.Environ() { if strings.HasPrefix(e, "GJ_") || strings.HasPrefix(e, "SJ_") { kv := strings.SplitN(e, "=", 2) - util.SetKeyValue(vi, kv[0], kv[1]) + util.SetKeyValue(viper, kv[0], kv[1]) } } - c := &Config{vi: vi} - c.Serv.ConfigPath = cp + config := &Config{viper: viper} + config.Serv.ConfigPath = cp - if err := vi.Unmarshal(c); err != nil { + if err := viper.Unmarshal(&config); err != nil { return nil, fmt.Errorf("failed to decode config, %v", err) } - return c, nil + return config, nil } +// NewConfig function creates a new GraphJin configuration from the provided config string func NewConfig(config, format string) (*Config, error) { if format == "" { format = "yaml" @@ -355,22 +359,23 @@ func NewConfig(config, format string) (*Config, error) { } } - vi := newViperWithDefaults() - vi.SetConfigType(format) + viper := newViperWithDefaults() + viper.SetConfigType(format) - if err := vi.ReadConfig(strings.NewReader(config)); err != nil { + if err := viper.ReadConfig(strings.NewReader(config)); err != nil { return nil, err } - c := &Config{vi: vi} + c := &Config{viper: viper} - if err := vi.Unmarshal(c); err != nil { + if err := viper.Unmarshal(&c); err != nil { return nil, fmt.Errorf("failed to decode config, %v", err) } return c, nil } +// newViperWithDefaults returns a new viper instance with the default settings func newViperWithDefaults() *viper.Viper { vi := viper.New() @@ -406,6 +411,7 @@ func newViperWithDefaults() *viper.Viper { return vi } +// newViper returns a new viper instance with the default settings func newViper(configPath, configFile string) *viper.Viper { vi := newViperWithDefaults() vi.SetConfigName(strings.TrimSuffix(configFile, filepath.Ext(configFile))) @@ -419,45 +425,54 @@ func newViper(configPath, configFile string) *viper.Viper { return vi } -func (c *Config) GetSecret(k string) (string, bool) { - v, ok := c.secrets[k] - return v, ok +// GetSecret returns the value of the secret key +// if it exists +func (c *Config) GetSecret(key string) (string, bool) { + value, ok := c.secrets[key] + return value, ok } -func (c *Config) GetSecretOrEnv(k string) string { - if v, ok := c.GetSecret(k); ok { - return v +// GetSecretOrEnv returns the value of the secret key if +// it exists or the value of the environment variable +func (c *Config) GetSecretOrEnv(key string) string { + if value, ok := c.GetSecret(key); ok { + return value } - return os.Getenv(k) + return os.Getenv(key) } // func (c *Config) telemetryEnabled() bool { // return c.Telemetry.Debug || c.Telemetry.Metrics.Exporter != "" || c.Telemetry.Tracing.Exporter != "" // } -func (c *Config) RelPath(p string) string { +// AbsolutePath returns the absolute path of the file +func (c *Config) AbsolutePath(p string) string { if filepath.IsAbs(p) { return p } return filepath.Join(c.Serv.ConfigPath, p) } +// SetHash sets the hash value of the configuration func (c *Config) SetHash(hash string) { c.hash = hash } +// SetName sets the name of the configuration func (c *Config) SetName(name string) { c.name = name } +// rateLimiterEnable returns true if the rate limiter is enabled func (c *Config) rateLimiterEnable() bool { return c.RateLimiter.Rate > 0 && c.RateLimiter.Bucket > 0 } +// GetConfigName returns the name of the configuration func GetConfigName() string { - ge := strings.TrimSpace(strings.ToLower(os.Getenv("GO_ENV"))) + goEnv := strings.TrimSpace(strings.ToLower(os.Getenv("GO_ENV"))) - switch ge { + switch goEnv { case "production", "prod": return "prod" @@ -471,6 +486,6 @@ func GetConfigName() string { return "dev" default: - return ge + return goEnv } } diff --git a/serv/db.go b/serv/db.go index f0e122c3..3de28687 100644 --- a/serv/db.go +++ b/serv/db.go @@ -34,10 +34,12 @@ type dbConf struct { connString string } +// Config holds the configuration for the service func NewDB(conf *Config, openDB bool, log *zap.SugaredLogger, fs core.FS) (*sql.DB, error) { return newDB(conf, openDB, false, log, fs) } +// newDB initializes the database func newDB( conf *Config, openDB, useTelemetry bool, @@ -97,43 +99,44 @@ func newDB( } } +// initPostgres initializes the postgres database func initPostgres(conf *Config, openDB, useTelemetry bool, fs core.FS) (*dbConf, error) { - c := conf - config, _ := pgx.ParseConfig(c.DB.ConnString) - if c.DB.Host != "" { - config.Host = c.DB.Host + confCopy := conf + config, _ := pgx.ParseConfig(confCopy.DB.ConnString) + if confCopy.DB.Host != "" { + config.Host = confCopy.DB.Host } - if c.DB.Port != 0 { - config.Port = c.DB.Port + if confCopy.DB.Port != 0 { + config.Port = confCopy.DB.Port } - if c.DB.User != "" { - config.User = c.DB.User + if confCopy.DB.User != "" { + config.User = confCopy.DB.User } - if c.DB.Password != "" { - config.Password = c.DB.Password + if confCopy.DB.Password != "" { + config.Password = confCopy.DB.Password } if config.RuntimeParams == nil { config.RuntimeParams = map[string]string{} } - if c.DB.Schema != "" { - config.RuntimeParams["search_path"] = c.DB.Schema + if confCopy.DB.Schema != "" { + config.RuntimeParams["search_path"] = confCopy.DB.Schema } - if c.AppName != "" { - config.RuntimeParams["application_name"] = c.AppName + if confCopy.AppName != "" { + config.RuntimeParams["application_name"] = confCopy.AppName } - // if openDB { - config.Database = c.DB.DBName - // } + if openDB { + config.Database = confCopy.DB.DBName + } - if c.DB.EnableTLS { - if len(c.DB.ServerName) == 0 { + if confCopy.DB.EnableTLS { + if len(confCopy.DB.ServerName) == 0 { return nil, errors.New("tls: server_name is required") } - if len(c.DB.ServerCert) == 0 { + if len(confCopy.DB.ServerCert) == 0 { return nil, errors.New("tls: server_cert is required") } @@ -141,10 +144,10 @@ func initPostgres(conf *Config, openDB, useTelemetry bool, fs core.FS) (*dbConf, var pem []byte var err error - if strings.Contains(c.DB.ServerCert, pemSig) { - pem = []byte(strings.ReplaceAll(c.DB.ServerCert, `\n`, "\n")) + if strings.Contains(confCopy.DB.ServerCert, pemSig) { + pem = []byte(strings.ReplaceAll(confCopy.DB.ServerCert, `\n`, "\n")) } else { - pem, err = fs.Get(c.DB.ServerCert) + pem, err = fs.Get(confCopy.DB.ServerCert) } if err != nil { @@ -158,24 +161,24 @@ func initPostgres(conf *Config, openDB, useTelemetry bool, fs core.FS) (*dbConf, config.TLSConfig = &tls.Config{ MinVersion: tls.VersionTLS12, RootCAs: rootCertPool, - ServerName: c.DB.ServerName, + ServerName: confCopy.DB.ServerName, } - if len(c.DB.ClientCert) > 0 { - if len(c.DB.ClientKey) == 0 { + if len(confCopy.DB.ClientCert) > 0 { + if len(confCopy.DB.ClientKey) == 0 { return nil, errors.New("tls: client_key is required") } clientCert := make([]tls.Certificate, 0, 1) var certs tls.Certificate - if strings.Contains(c.DB.ClientCert, pemSig) { + if strings.Contains(confCopy.DB.ClientCert, pemSig) { certs, err = tls.X509KeyPair( - []byte(strings.ReplaceAll(c.DB.ClientCert, `\n`, "\n")), - []byte(strings.ReplaceAll(c.DB.ClientKey, `\n`, "\n")), + []byte(strings.ReplaceAll(confCopy.DB.ClientCert, `\n`, "\n")), + []byte(strings.ReplaceAll(confCopy.DB.ClientKey, `\n`, "\n")), ) } else { - certs, err = loadX509KeyPair(fs, c.DB.ClientCert, c.DB.ClientKey) + certs, err = loadX509KeyPair(fs, confCopy.DB.ClientCert, confCopy.DB.ClientKey) } if err != nil { @@ -190,6 +193,7 @@ func initPostgres(conf *Config, openDB, useTelemetry bool, fs core.FS) (*dbConf, return &dbConf{"pgx", stdlib.RegisterConnConfig(config)}, nil } +// initMysql initializes the mysql database func initMysql(conf *Config, openDB, useTelemetry bool, fs core.FS) (*dbConf, error) { var connString string c := conf @@ -207,6 +211,7 @@ func initMysql(conf *Config, openDB, useTelemetry bool, fs core.FS) (*dbConf, er return &dbConf{"mysql", connString}, nil } +// loadX509KeyPair loads a X509 key pair from a file system func loadX509KeyPair(fs core.FS, certFile, keyFile string) ( cert tls.Certificate, err error, ) { diff --git a/serv/deploy.go b/serv/deploy.go index c340a94a..ae7a26b5 100644 --- a/serv/deploy.go +++ b/serv/deploy.go @@ -23,7 +23,8 @@ type depResp struct { name, pname string } -func (s *service) saveConfig(c context.Context, name, bundle string) (*depResp, error) { +// saveConfig saves the config to the database +func (s *graphjinService) saveConfig(c context.Context, name, bundle string) (*depResp, error) { var dres depResp zip, err := base64.StdEncoding.DecodeString(bundle) @@ -139,7 +140,8 @@ func (s *service) saveConfig(c context.Context, name, bundle string) (*depResp, return &dres, nil } -func (s *service) rollbackConfig(c context.Context) (*depResp, error) { +// rollbackConfig rolls back the config to the previous one +func (s *graphjinService) rollbackConfig(c context.Context) (*depResp, error) { var dres depResp opt := &sql.TxOptions{Isolation: sql.LevelSerializable} @@ -216,6 +218,7 @@ type adminParams struct { params map[string]string } +// getAdminParams fetches the admin params from the database func getAdminParams(tx *sql.Tx) (adminParams, error) { var ap adminParams @@ -259,14 +262,15 @@ func getAdminParams(tx *sql.Tx) (adminParams, error) { return ap, nil } -func startHotDeployWatcher(s1 *Service) error { +// startHotDeployWatcher starts the hot deploy watcher +func startHotDeployWatcher(s1 *HttpService) error { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() for range ticker.C { - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) - cf := s.conf.vi.ConfigFileUsed() + cf := s.conf.viper.ConfigFileUsed() cf = filepath.Join("/", filepath.Base(strings.TrimSuffix(cf, filepath.Ext(cf)))) var id int @@ -322,6 +326,7 @@ type activeBundle struct { name, hash, bundle string } +// fetchActiveBundle fetches the active bundle from the database func fetchActiveBundle(db *sql.DB) (*activeBundle, error) { var ab activeBundle @@ -346,7 +351,8 @@ func fetchActiveBundle(db *sql.DB) (*activeBundle, error) { return &ab, nil } -func deployBundle(s1 *Service, name, hash, confFile, bundle string) error { +// deployBundle deploys the bundle to the server +func deployBundle(s1 *HttpService, name, hash, confFile, bundle string) error { bfs, err := bundle2Fs(name, hash, confFile, bundle) if err != nil { return err @@ -360,6 +366,7 @@ type bundleFs struct { fs afero.Fs } +// bundle2Fs converts the bundle to a filesystem func bundle2Fs(name, hash, confFile, bundle string) (bundleFs, error) { var bfs bundleFs diff --git a/serv/filewatch.go b/serv/filewatch.go index 6546351e..29e0f36e 100644 --- a/serv/filewatch.go +++ b/serv/filewatch.go @@ -13,7 +13,8 @@ import ( "github.com/pkg/errors" ) -func startConfigWatcher(s1 *Service) error { +// startConfigWatcher watches for changes in the config file +func startConfigWatcher(s1 *HttpService) error { var watcher *fsnotify.Watcher var err error @@ -59,7 +60,7 @@ func startConfigWatcher(s1 *Service) error { } for { - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) select { case err := <-watcher.Errors: @@ -88,7 +89,7 @@ func startConfigWatcher(s1 *Service) error { } // Check if new config is valid - cf := s.conf.RelPath(GetConfigName()) + cf := s.conf.AbsolutePath(GetConfigName()) conf, err := readInConfig(cf, nil) if err != nil { s.log.Error(err) diff --git a/serv/health.go b/serv/health.go index ca7b967b..837df682 100644 --- a/serv/health.go +++ b/serv/health.go @@ -10,9 +10,10 @@ import ( var healthyResponse = []byte("All's Well") -func healthCheckHandler(s1 *Service) http.Handler { +// healthCheckHandler returns a handler that checks the health of the service +func healthCheckHandler(s1 *HttpService) http.Handler { h := func(w http.ResponseWriter, r *http.Request) { - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) c, cancel := context.WithTimeout(r.Context(), s.conf.DB.PingTimeout) defer cancel() diff --git a/serv/http.go b/serv/http.go index 2cbab5ee..474becc8 100644 --- a/serv/http.go +++ b/serv/http.go @@ -55,9 +55,10 @@ type errorResp struct { Errors []string `json:"errors"` } -func apiV1Handler(s1 *Service, ns *string, h http.Handler, ah auth.HandlerFunc) http.Handler { +// apiV1Handler is the main handler for all API requests +func apiV1Handler(s1 *HttpService, ns *string, h http.Handler, ah auth.HandlerFunc) http.Handler { var zlog *zap.Logger - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) if s.conf.Core.Debug { zlog = s.zlog @@ -107,14 +108,15 @@ func apiV1Handler(s1 *Service, ns *string, h http.Handler, ah auth.HandlerFunc) return h } -func (s1 *Service) apiV1GraphQL(ns *string, ah auth.HandlerFunc) http.Handler { +// apiV1GraphQLHandler handles the GraphQL API requests +func (s1 *HttpService) apiV1GraphQL(ns *string, ah auth.HandlerFunc) http.Handler { dtrace := otel.GetTextMapPropagator() h := func(w http.ResponseWriter, r *http.Request) { var err error start := time.Now() - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) w.Header().Set("Content-Type", "application/json") @@ -155,7 +157,7 @@ func (s1 *Service) apiV1GraphQL(ns *string, ah auth.HandlerFunc) http.Handler { return } - var rc core.ReqConfig + var rc core.RequestConfig if req.apqEnabled() { rc.APQKey = (req.OpName + req.Ext.Persisted.Sha256Hash) @@ -205,7 +207,8 @@ func (s1 *Service) apiV1GraphQL(ns *string, ah auth.HandlerFunc) http.Handler { return http.HandlerFunc(h) } -func (s1 *Service) apiV1Rest(ns *string, ah auth.HandlerFunc) http.Handler { +// apiV1Rest returns a handler that handles the REST API requests +func (s1 *HttpService) apiV1Rest(ns *string, ah auth.HandlerFunc) http.Handler { rLen := len(routeREST) dtrace := otel.GetTextMapPropagator() @@ -213,7 +216,7 @@ func (s1 *Service) apiV1Rest(ns *string, ah auth.HandlerFunc) http.Handler { var err error start := time.Now() - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) w.Header().Set("Content-Type", "application/json") @@ -255,7 +258,7 @@ func (s1 *Service) apiV1Rest(ns *string, ah auth.HandlerFunc) http.Handler { return } - var rc core.ReqConfig + var rc core.RequestConfig if rc.Vars == nil && len(s.conf.Core.HeaderVars) != 0 { rc.Vars = s.setHeaderVars(r) @@ -288,11 +291,12 @@ func (s1 *Service) apiV1Rest(ns *string, ah auth.HandlerFunc) http.Handler { return http.HandlerFunc(h) } -func (s *service) responseHandler(ct context.Context, +// responseHandler handles the response from the GraphQL API +func (s *graphjinService) responseHandler(ct context.Context, w http.ResponseWriter, r *http.Request, start time.Time, - rc core.ReqConfig, + rc core.RequestConfig, res *core.Result, err error, ) { @@ -330,7 +334,8 @@ func (s *service) responseHandler(ct context.Context, } } -func (s *service) reqLog(res *core.Result, rc core.ReqConfig, resTimeMs int64, err error) { +// reqLog logs the request details +func (s *graphjinService) reqLog(res *core.Result, rc core.RequestConfig, resTimeMs int64, err error) { var fields []zapcore.Field var sql string @@ -373,7 +378,8 @@ func (s *service) reqLog(res *core.Result, rc core.ReqConfig, resTimeMs int64, e } } -func (s *service) setHeaderVars(r *http.Request) map[string]interface{} { +// setHeaderVars sets the header variables +func (s *graphjinService) setHeaderVars(r *http.Request) map[string]interface{} { vars := make(map[string]interface{}) for k, v := range s.conf.Core.HeaderVars { vars[k] = func() string { @@ -386,11 +392,12 @@ func (s *service) setHeaderVars(r *http.Request) map[string]interface{} { return vars } +// apqEnabled checks if the APQ is enabled func (r gqlReq) apqEnabled() bool { return r.Ext.Persisted.Sha256Hash != "" } -// nolint:errcheck +// renderErr renders the error response func renderErr(w http.ResponseWriter, err error) { if err == errUnauthorized { w.WriteHeader(http.StatusUnauthorized) @@ -402,6 +409,7 @@ func renderErr(w http.ResponseWriter, err error) { } } +// parseBody parses the request body func parseBody(r *http.Request) ([]byte, error) { b, err := io.ReadAll(io.LimitReader(r.Body, maxReadBytes)) if err != nil { @@ -411,6 +419,7 @@ func parseBody(r *http.Request) ([]byte, error) { return b, nil } +// newDTrace creates a new DTrace func newDTrace(dtrace propagation.TextMapPropagator, r *http.Request) (context.Context, []trace.SpanStartOption) { ctx := dtrace.Extract(r.Context(), propagation.HeaderCarrier(r.Header)) diff --git a/serv/init.go b/serv/init.go index 361e56de..2de22bc9 100644 --- a/serv/init.go +++ b/serv/init.go @@ -10,7 +10,8 @@ import ( "github.com/dosco/graphjin/core/v3" ) -func initLogLevel(s *service) { +// initLogLevel initializes the log level +func initLogLevel(s *graphjinService) { switch s.conf.LogLevel { case "debug": s.logLevel = logLevelDebug @@ -25,7 +26,8 @@ func initLogLevel(s *service) { } } -func validateConf(s *service) { +// validateConf validates the configuration +func validateConf(s *graphjinService) { var anonFound bool for _, r := range s.conf.Core.Roles { @@ -40,7 +42,8 @@ func validateConf(s *service) { } } -func (s *service) initFS() error { +// initFS initializes the file system +func (s *graphjinService) initFS() error { basePath, err := s.basePath() if err != nil { return err @@ -53,7 +56,8 @@ func (s *service) initFS() error { return nil } -func (s *service) initConfig() error { +// initConfig initializes the configuration +func (s *graphjinService) initConfig() error { c := s.conf c.dirty = true @@ -96,7 +100,8 @@ func (s *service) initConfig() error { return nil } -func (s *service) initDB() error { +// initDB initializes the database +func (s *graphjinService) initDB() error { var err error if s.db != nil { @@ -110,7 +115,8 @@ func (s *service) initDB() error { return nil } -func (s *service) basePath() (string, error) { +// basePath returns the base path +func (s *graphjinService) basePath() (string, error) { if s.conf.Serv.ConfigPath == "" { if cp, err := os.Getwd(); err == nil { return filepath.Join(cp, "config"), nil diff --git a/serv/internal/secrets/decrypt.go b/serv/internal/secrets/decrypt.go index f2b9923e..6dcf3579 100644 --- a/serv/internal/secrets/decrypt.go +++ b/serv/internal/secrets/decrypt.go @@ -22,6 +22,7 @@ type decryptOpts struct { KeyServices []keyservice.KeyServiceClient } +// decrypt decrypts the file at the given path using options passed. func decrypt(opts decryptOpts, fs afero.Fs) (decryptedFile []byte, err error) { tree, err := LoadEncryptedFileWithBugFixes(common.GenericDecryptOpts{ Cipher: opts.Cipher, @@ -84,6 +85,7 @@ func extract(tree *sops.Tree, path []interface{}, outputStore sops.Store) ([]byt return bytes, nil } +// LoadEncryptedFileWithBugFixes loads an encrypted file from the given path and applies bug fixes. func LoadEncryptedFileWithBugFixes( opts common.GenericDecryptOpts, fs afero.Fs) (*sops.Tree, error) { @@ -112,6 +114,7 @@ func LoadEncryptedFileWithBugFixes( return tree, nil } +// LoadEncryptedFile loads an encrypted file from the given path. func LoadEncryptedFile( loader sops.EncryptedFileLoader, inputPath string, diff --git a/serv/internal/secrets/edit.go b/serv/internal/secrets/edit.go index 7c03b5d2..c67726cd 100644 --- a/serv/internal/secrets/edit.go +++ b/serv/internal/secrets/edit.go @@ -56,6 +56,7 @@ GJ_ADMIN_SECRET_KEY: hotdeploy_admin_secret_key GJ_SECRET_KEY: graphjin_generic_secret_key GJ_AUTH_JWT_SECRET: jwt_auth_secret_key` +// editExample edits the example file func editExample(opts editExampleOpts) ([]byte, error) { branches, err := opts.InputStore.LoadPlainFile([]byte(fileBytes)) if err != nil { @@ -87,6 +88,7 @@ func editExample(opts editExampleOpts) ([]byte, error) { return editTree(opts.editOpts, &tree, dataKey) } +// edit edits the file at the given path using options passed. func edit(opts editOpts) ([]byte, error) { // Load the file tree, err := common.LoadEncryptedFileWithBugFixes(common.GenericDecryptOpts{ @@ -113,6 +115,7 @@ func edit(opts editOpts) ([]byte, error) { return editTree(opts, tree, dataKey) } +// editTree edits the tree using the options passed. func editTree(opts editOpts, tree *sops.Tree, dataKey []byte) ([]byte, error) { // Create temporary file for editing tmpdir, err := os.MkdirTemp("", "") @@ -180,6 +183,7 @@ func editTree(opts editOpts, tree *sops.Tree, dataKey []byte) ([]byte, error) { return encryptedFile, nil } +// runEditorUntilOk runs the editor until the file is saved and the hash is different func runEditorUntilOk(opts runEditorUntilOkOpts) error { for { err := runEditor(opts.TmpFile.Name()) @@ -240,6 +244,7 @@ func runEditorUntilOk(opts runEditorUntilOkOpts) error { return nil } +// hashFile returns the MD5 hash of the file at the given path func hashFile(filePath string) ([]byte, error) { var result []byte file, err := os.Open(filePath) @@ -254,6 +259,7 @@ func hashFile(filePath string) ([]byte, error) { return hash.Sum(result), nil } +// runEditor runs the editor func runEditor(path string) error { editor := os.Getenv("EDITOR") var cmd *exec.Cmd @@ -279,6 +285,7 @@ func runEditor(path string) error { return cmd.Run() } +// lookupAnyEditor looks up the first available editor func lookupAnyEditor(editorNames ...string) (editorPath string, err error) { for _, editorName := range editorNames { editorPath, err = exec.LookPath(editorName) diff --git a/serv/internal/secrets/init.go b/serv/internal/secrets/init.go index 9113af11..6efdfd9a 100644 --- a/serv/internal/secrets/init.go +++ b/serv/internal/secrets/init.go @@ -11,11 +11,8 @@ import ( "go.mozilla.org/sops/v3/stores/dotenv" ) +// Init reads the secrets from the given file and returns them as a map func Init(filename string, fs afero.Fs) (map[string]string, error) { - return initSecrets(filename, fs) -} - -func initSecrets(filename string, fs afero.Fs) (map[string]string, error) { var err error inputStore := common.DefaultStoreForPath(filename) diff --git a/serv/internal/secrets/rotate.go b/serv/internal/secrets/rotate.go index 846cfa3a..1a3a0512 100644 --- a/serv/internal/secrets/rotate.go +++ b/serv/internal/secrets/rotate.go @@ -24,6 +24,7 @@ type rotateOpts struct { KeyServices []keyservice.KeyServiceClient } +// rotate rotates the keys in the file at the given path using options passed. func rotate(opts rotateOpts) ([]byte, error) { tree, err := common.LoadEncryptedFileWithBugFixes(common.GenericDecryptOpts{ Cipher: opts.Cipher, diff --git a/serv/internal/secrets/run.go b/serv/internal/secrets/run.go index d2e6a420..8c8c24e0 100644 --- a/serv/internal/secrets/run.go +++ b/serv/internal/secrets/run.go @@ -21,6 +21,7 @@ type SecretArgs struct { KMS, KMSC, AWS, GCP, Azure, PGP string //nolint:golint,unused } +// SecretsCmd is the entry point for the secrets command func SecretsCmd(cmdName, fileName string, sa SecretArgs, args []string, log *zap.SugaredLogger) error { var err error @@ -148,6 +149,7 @@ func SecretsCmd(cmdName, fileName string, sa SecretArgs, args []string, log *zap return nil } +// keyGroups returns a slice of key groups based on the secret arguments func keyGroups(sa SecretArgs, file string) ([]sops.KeyGroup, error) { var kmsKeys []keys.MasterKey var pgpKeys []keys.MasterKey diff --git a/serv/internal/util/log.go b/serv/internal/util/log.go index 7c0c8774..262ea160 100644 --- a/serv/internal/util/log.go +++ b/serv/internal/util/log.go @@ -7,6 +7,8 @@ import ( "go.uber.org/zap/zapcore" ) +// NewLogger creates a new zap logger instance +// json - if true logs are in json format func NewLogger(json bool) *zap.Logger { econf := zapcore.EncoderConfig{ MessageKey: "msg", diff --git a/serv/internal/util/viper.go b/serv/internal/util/viper.go index 5621b87d..81c3ebc8 100644 --- a/serv/internal/util/viper.go +++ b/serv/internal/util/viper.go @@ -6,6 +6,7 @@ import ( "github.com/spf13/viper" ) +// SetKeyValue sets the value of a key in the viper config func SetKeyValue(vi *viper.Viper, key string, value interface{}) bool { if strings.HasPrefix(key, "GJ_") || strings.HasPrefix(key, "SG_") { key = key[3:] diff --git a/serv/iplimiter.go b/serv/iplimiter.go index 7c648d8b..513a0466 100644 --- a/serv/iplimiter.go +++ b/serv/iplimiter.go @@ -14,10 +14,12 @@ import ( var ipCache cache.Cache +// init initializes the cache func init() { ipCache, _ = cache.NewCache(cache.MaxKeys(10), cache.TTL(time.Minute*5)) } +// getIPLimiter returns the rate limiter for the given IP func getIPLimiter(ip string, limit float64, bucket int) *rate.Limiter { v, exists := ipCache.Get(ip) if !exists { @@ -29,11 +31,12 @@ func getIPLimiter(ip string, limit float64, bucket int) *rate.Limiter { return v.(*rate.Limiter) } -func rateLimiter(s1 *Service, h http.Handler) http.Handler { +// rateLimiter is a middleware that limits the number of requests per IP +func rateLimiter(s1 *HttpService, h http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { var iph, ip string var err error - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) if s.conf.RateLimiter.IPHeader != "" { iph = r.Header.Get(s.conf.RateLimiter.IPHeader) diff --git a/serv/migrate.go b/serv/migrate.go index e3aa920d..601c53eb 100644 --- a/serv/migrate.go +++ b/serv/migrate.go @@ -33,6 +33,7 @@ CREATE TABLE _graphjin.configs ( CREATE INDEX config_active ON _graphjin.configs (active); ` +// InitAdmin creates the admin tables func InitAdmin(db *sql.DB, dbtype string) error { c := context.Background() @@ -52,6 +53,7 @@ func InitAdmin(db *sql.DB, dbtype string) error { return nil } +// idColSql returns the id column sql func idColSql(dbtype string) string { switch dbtype { case "mysql": diff --git a/serv/routes.go b/serv/routes.go index db7bd1ca..cebd8f57 100644 --- a/serv/routes.go +++ b/serv/routes.go @@ -17,8 +17,9 @@ type Mux interface { ServeHTTP(http.ResponseWriter, *http.Request) } -func routesHandler(s1 *Service, mux Mux, ns *string) (http.Handler, error) { - s := s1.Load().(*service) +// routesHandler is the main handler for all routes +func routesHandler(s1 *HttpService, mux Mux, ns *string) (http.Handler, error) { + s := s1.Load().(*graphjinService) // Healthcheck API mux.Handle(healthRoute, healthCheckHandler(s1)) diff --git a/serv/secrets.go b/serv/secrets.go index 28c379e4..a8c3aff7 100644 --- a/serv/secrets.go +++ b/serv/secrets.go @@ -6,15 +6,18 @@ import ( "go.uber.org/zap" ) +// SecretArgs holds the arguments for the secrets command type SecretArgs struct { KMS, KMSC, AWS, GCP, Azure, PGP string } +// SecretsCmd runs the secrets command func SecretsCmd(cmdName, fileName string, sa SecretArgs, args []string, log *zap.SugaredLogger) error { return secrets.SecretsCmd( cmdName, fileName, secrets.SecretArgs(sa), args, log) } +// InitSecrets initializes the secrets from the secrets file func initSecrets(secFile string, fs afero.Fs) (map[string]string, error) { return secrets.Init(secFile, fs) } diff --git a/serv/serv.go b/serv/serv.go index 9d090358..26870ae8 100644 --- a/serv/serv.go +++ b/serv/serv.go @@ -20,8 +20,9 @@ const ( defaultHP = "0.0.0.0:8080" ) -func initConfigWatcher(s1 *Service) { - s := s1.Load().(*service) +// Initialize the watcher for the graphjin config file +func initConfigWatcher(s1 *HttpService) { + s := s1.Load().(*graphjinService) if s.conf.Serv.Production { return } @@ -34,8 +35,9 @@ func initConfigWatcher(s1 *Service) { }() } -func initHotDeployWatcher(s1 *Service) { - s := s1.Load().(*service) +// Initialize the hot deploy watcher +func initHotDeployWatcher(s1 *HttpService) { + s := s1.Load().(*graphjinService) go func() { err := startHotDeployWatcher(s1) if err != nil { @@ -44,8 +46,9 @@ func initHotDeployWatcher(s1 *Service) { }() } -func startHTTP(s1 *Service) { - s := s1.Load().(*service) +// Start the HTTP server +func startHTTP(s1 *HttpService) { + s := s1.Load().(*graphjinService) r := chi.NewRouter() routes, err := routesHandler(s1, r, s.namespace) @@ -125,6 +128,7 @@ func startHTTP(s1 *Service) { <-idleConnsClosed } +// Set the server header func setServerHeader(h http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Server", serverName) diff --git a/serv/telemetry.go b/serv/telemetry.go index 4386e6a2..58aa5d58 100644 --- a/serv/telemetry.go +++ b/serv/telemetry.go @@ -10,9 +10,10 @@ import ( semconv "go.opentelemetry.io/otel/semconv/v1.24.0" ) +// InitTelemetry initializes the OpenTelemetry SDK with the given exporter and service name. func InitTelemetry( - c context.Context, - exp trace.SpanExporter, + context context.Context, + exporter trace.SpanExporter, serviceName, serviceInstanceID string, ) error { r1 := resource.NewWithAttributes( @@ -27,7 +28,7 @@ func InitTelemetry( } provider := trace.NewTracerProvider( - trace.WithBatcher(exp), + trace.WithBatcher(exporter), trace.WithResource(r2), trace.WithSampler(trace.AlwaysSample()), ) diff --git a/serv/webui.go b/serv/webui.go index 38b27055..a991b4b4 100644 --- a/serv/webui.go +++ b/serv/webui.go @@ -10,6 +10,7 @@ import ( //go:embed web/build var webBuild embed.FS +// webuiHandler serves the web UI func webuiHandler(routePrefix string, gqlEndpoint string) http.Handler { webRoot, _ := fs.Sub(webBuild, "web/build") fs := http.FileServer(http.FS(webRoot)) diff --git a/serv/ws.go b/serv/ws.go index 2a6b5b74..f99a075c 100644 --- a/serv/ws.go +++ b/serv/ws.go @@ -74,7 +74,8 @@ type wsState struct { done chan bool } -func (s *service) apiV1Ws(w http.ResponseWriter, r *http.Request, ah auth.HandlerFunc) { +// apiV1Ws handles the websocket connection +func (s *graphjinService) apiV1Ws(w http.ResponseWriter, r *http.Request, ah auth.HandlerFunc) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { renderErr(w, err) @@ -127,7 +128,8 @@ type authHeaders struct { UserID interface{} `json:"X-User-ID"` } -func (s *service) subSwitch(wc *wsConn, req wsReq) (err error) { +// subSwitch handles the websocket message types +func (s *graphjinService) subSwitch(wc *wsConn, req wsReq) (err error) { switch req.Type { case "connection_init": if err = setHeaders(req, wc.r); err != nil { @@ -198,7 +200,8 @@ func (s *service) subSwitch(wc *wsConn, req wsReq) (err error) { return } -func (s *service) waitForData(wc *wsConn, st *wsState, useNext bool) { +// waitForData waits for data from the subscription +func (s *graphjinService) waitForData(wc *wsConn, st *wsState, useNext bool) { var buf bytes.Buffer var ptype string @@ -247,6 +250,7 @@ func (s *service) waitForData(wc *wsConn, st *wsState, useNext bool) { } } +// setHeaders sets the headers from the payload func setHeaders(req wsReq, r *http.Request) (err error) { if len(req.Payload) == 0 { return @@ -266,6 +270,7 @@ func setHeaders(req wsReq, r *http.Request) (err error) { return } +// sendError sends an error message to the client func sendError(wc *wsConn, id string, cerr error) (err error) { m := wsRes{ID: id, Type: "error"} m.Payload.Errors = []core.Error{{Message: cerr.Error()}} diff --git a/tests/core_test.go b/tests/core_test.go index cd9b3e6d..b1c8f2a3 100644 --- a/tests/core_test.go +++ b/tests/core_test.go @@ -65,7 +65,7 @@ func TestAPQ(t *testing.T) { return } - _, err = gj.GraphQL(context.Background(), gql, nil, &core.ReqConfig{ + _, err = gj.GraphQL(context.Background(), gql, nil, &core.RequestConfig{ APQKey: "getProducts", }) if err != nil { @@ -73,7 +73,7 @@ func TestAPQ(t *testing.T) { return } - res, err := gj.GraphQL(context.Background(), "", nil, &core.ReqConfig{ + res, err := gj.GraphQL(context.Background(), "", nil, &core.RequestConfig{ APQKey: "getProducts", }) if err != nil { @@ -199,7 +199,7 @@ func TestAllowListWithNamespace(t *testing.T) { return } - var rc core.ReqConfig + var rc core.RequestConfig rc.SetNamespace("api") _, err = gj2.GraphQL(context.Background(), gql2, nil, &rc)