Skip to content

Commit

Permalink
utility function to get module name (#783)
Browse files Browse the repository at this point in the history
## Description
This Pull Request provides a utility function to get the module's name.
  • Loading branch information
fabianburth authored May 29, 2024
1 parent 509ba89 commit ba23b9f
Show file tree
Hide file tree
Showing 8 changed files with 275 additions and 80 deletions.
50 changes: 14 additions & 36 deletions pkg/env/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package env
import (
"bytes"
"fmt"
"runtime"
"runtime/debug"
"strings"

Expand All @@ -25,7 +24,9 @@ import (
"github.com/open-component-model/ocm/pkg/contexts/datacontext/attrs/vfsattr"
"github.com/open-component-model/ocm/pkg/contexts/oci"
ocm "github.com/open-component-model/ocm/pkg/contexts/ocm/cpi"
"github.com/open-component-model/ocm/pkg/testutils"
"github.com/open-component-model/ocm/pkg/utils"
"github.com/open-component-model/ocm/pkg/utils/pkgutils"
)

////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -232,23 +233,13 @@ func ModifiableTestData(paths ...string) tdOpt {
}

func projectTestData(modifiable bool, source string, dest ...string) Option {
path := "."
for count := 0; count < 20; count++ {
if ok, err := vfs.FileExists(osfs.OsFs, filepath.Join(path, "go.mod")); err != nil || ok {
if err != nil {
panic(err)
}
path = filepath.Join(path, source)
break
}
if count == 19 {
panic("could not find go.mod (within 20 steps)")
}

path = filepath.Join(path, "..")
pathToRoot, err := testutils.GetRelativePathToProjectRoot()
if err != nil {
panic(err)
}
pathToTestdata := filepath.Join(pathToRoot, source)

return testData(modifiable, path, general.OptionalDefaulted("/testdata", dest...))
return testData(modifiable, pathToTestdata, general.OptionalDefaulted("/testdata", dest...))
}

func ProjectTestData(source string, dest ...string) Option {
Expand All @@ -260,29 +251,16 @@ func ModifiableProjectTestData(source string, dest ...string) Option {
}

func projectTestDataForCaller(modifiable bool, dest ...string) Option {
pc, _, _, ok := runtime.Caller(2)
if !ok {
panic("unable to find caller")
}

// Get the function details from the program counter
caller := runtime.FuncForPC(pc)
if caller == nil {
panic("unable to find caller")
packagePath, err := pkgutils.GetPackageName(2)
if err != nil {
panic(err)
}

fullFuncName := caller.Name()

// Split the name to extract the package path
// Assuming the format: "package/path.functionName"
lastSlashIndex := strings.LastIndex(fullFuncName, "/")
if lastSlashIndex == -1 {
panic("unable to find package name")
moduleName, err := testutils.GetModuleName()
if err != nil {
panic(err)
}

funcIndex := strings.Index(fullFuncName[lastSlashIndex:], ".")
packagePath := fullFuncName[:lastSlashIndex+funcIndex]
path, ok := strings.CutPrefix(packagePath, "github.com/open-component-model/ocm/")
path, ok := strings.CutPrefix(packagePath, moduleName+"/")
if !ok {
panic("unable to find package name")
}
Expand Down
78 changes: 78 additions & 0 deletions pkg/testutils/package.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package testutils

import (
"fmt"
"strings"

"github.com/mandelsoft/filepath/pkg/filepath"
"github.com/mandelsoft/goutils/general"
"github.com/mandelsoft/vfs/pkg/osfs"
"github.com/mandelsoft/vfs/pkg/vfs"
"golang.org/x/mod/modfile"

"github.com/open-component-model/ocm/pkg/utils/pkgutils"
)

const GO_MOD = "go.mod"

func GetPackagePathFromProjectRoot(i ...interface{}) (string, error) {
pkg, err := pkgutils.GetPackageName(i...)
if err != nil {
return "", err
}
mod, err := GetModuleName()
if err != nil {
return "", err
}
path, ok := strings.CutPrefix(pkg, mod+"/")
if !ok {
return "", fmt.Errorf("prefix %q not found in %q", mod, pkg)
}
return path, nil
}

// GetModuleName returns a go modules module name by finding and parsing the go.mod file.
func GetModuleName() (string, error) {
pathToRoot, err := GetRelativePathToProjectRoot()
if err != nil {
return "", err
}
pathToGoMod := filepath.Join(pathToRoot, GO_MOD)
// Read the content of the go.mod file
data, err := vfs.ReadFile(osfs.OsFs, pathToGoMod)
if err != nil {
return "", err
}

// Parse the go.mod file
modFile, err := modfile.Parse(GO_MOD, data, nil)
if err != nil {
return "", fmt.Errorf("error parsing %s file: %w", GO_MOD, err)
}

// Print the module path
return modFile.Module.Mod.Path, nil
}

// GetRelativePathToProjectRoot calculates the relative path to a go projects root directory.
// It therefore assumes that the project root is the directory containing the go.mod file.
// The optional parameter i determines how many directories the function will step up through, attempting to find a
// go.mod file. If it cannot find a directory with a go.mod file within i iterations, the function throws an error.
func GetRelativePathToProjectRoot(i ...int) (string, error) {
iterations := general.OptionalDefaulted(20, i...)

path := "."
for count := 0; count < iterations; count++ {
if ok, err := vfs.FileExists(osfs.OsFs, filepath.Join(path, GO_MOD)); err != nil || ok {
if err != nil {
return "", fmt.Errorf("failed to check if %s exists: %w", GO_MOD, err)
}
return path, nil
}
if count == iterations {
return "", fmt.Errorf("could not find %s (within %d steps)", GO_MOD, iterations)
}
path = filepath.Join(path, "..")
}
return "", nil
}
15 changes: 15 additions & 0 deletions pkg/testutils/package_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package testutils_test

import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"

me "github.com/open-component-model/ocm/pkg/testutils"
)

var _ = Describe("package tests", func() {
It("go module name", func() {
mod := me.Must(me.GetModuleName())
Expect(mod).To(Equal("github.com/open-component-model/ocm"))
})
})
44 changes: 0 additions & 44 deletions pkg/utils/package.go

This file was deleted.

102 changes: 102 additions & 0 deletions pkg/utils/pkgutils/package.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package pkgutils

import (
"fmt"
"reflect"
"runtime"
"strings"
)

// GetPackageName gets the package name for an object, a type, a function or a caller offset.
//
// Examples:
//
// GetPackageName(1)
// GetPackageName(&MyStruct{})
// GetPackageName(GetPackageName)
// GetPackageName(generics.TypeOf[MyStruct]())
func GetPackageName(i ...interface{}) (string, error) {
if len(i) == 0 {
i = []interface{}{0}
}
if t, ok := i[0].(reflect.Type); ok {
pkgpath := t.PkgPath()
if pkgpath == "" {
return "", fmt.Errorf("unable to determine package name")
}
return pkgpath, nil
}
v := reflect.ValueOf(i[0])
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
switch v.Kind() {
case reflect.Func:
return getPackageNameForFuncPC(v.Pointer())
case reflect.Struct, reflect.Chan, reflect.Map, reflect.Slice, reflect.Array:
pkgpath := v.Type().PkgPath()
if pkgpath == "" {
return "", fmt.Errorf("unable to determine package name")
}
return pkgpath, nil
default:
offset, err := CastInt(v.Interface())
if err != nil {
return "", err
}
pc, _, _, ok := runtime.Caller(offset + 1)
if !ok {
return "", fmt.Errorf("unable to find caller")
}
return getPackageNameForFuncPC(pc)
}
}

func getPackageNameForFuncPC(pc uintptr) (string, error) {
// Retrieve the function's runtime information
funcForPC := runtime.FuncForPC(pc)
if funcForPC == nil {
return "", fmt.Errorf("could not determine package name")
}
// Get the full name of the function, including the package path
fullFuncName := funcForPC.Name()

// Split the name to extract the package path
// Assuming the format: "package/path.functionName"
lastSlashIndex := strings.LastIndex(fullFuncName, "/")
if lastSlashIndex == -1 {
panic("unable to find package name")
}

funcIndex := strings.Index(fullFuncName[lastSlashIndex:], ".")
packagePath := fullFuncName[:lastSlashIndex+funcIndex]

return packagePath, nil
}

func CastInt(i interface{}) (int, error) {
switch v := i.(type) {
case int:
return v, nil
case int8:
return int(v), nil
case int16:
return int(v), nil
case int32:
return int(v), nil
case int64:
return int(v), nil
case uint:
return int(v), nil
case uint8:
return int(v), nil
case uint16:
return int(v), nil
case uint32:
return int(v), nil
case uint64:
return int(v), nil
default:
return 0, fmt.Errorf("unable to cast %T into int", i)
}
}
32 changes: 32 additions & 0 deletions pkg/utils/pkgutils/package_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package pkgutils_test

import (
"github.com/mandelsoft/goutils/generics"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
. "github.com/open-component-model/ocm/pkg/testutils"
me "github.com/open-component-model/ocm/pkg/utils/pkgutils"
"github.com/open-component-model/ocm/pkg/utils/pkgutils/testpackage"
"reflect"
)

type typ struct{}

var _ = Describe("package tests", func() {
DescribeTable("determine package type for ", func(typ interface{}) {
Expect(Must(me.GetPackageName(typ))).To(Equal(reflect.TypeOf(testpackage.MyStruct{}).PkgPath()))
},
Entry("struct", &testpackage.MyStruct{}),
Entry("array", &testpackage.MyArray{}),
Entry("list", &testpackage.MyList{}),
Entry("map", &testpackage.MyMap{}),
Entry("chan", make(testpackage.MyChan)),
Entry("func", testpackage.MyFunc),
Entry("func type", generics.TypeOf[testpackage.MyFuncType]()),
Entry("struct type", generics.TypeOf[testpackage.MyStruct]()),
)
It("determine package for caller func", func() {
Expect(Must(testpackage.MyFunc())).To(Equal(reflect.TypeOf(testpackage.MyStruct{}).PkgPath()))
Expect(Must(testpackage.MyFunc(1))).To(Equal(reflect.TypeOf(typ{}).PkgPath()))
})
})
13 changes: 13 additions & 0 deletions pkg/utils/pkgutils/suite_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package pkgutils_test

import (
"testing"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

func TestConfig(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Package Utils Test Suite")
}
21 changes: 21 additions & 0 deletions pkg/utils/pkgutils/testpackage/testtypes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package testpackage

import (
"github.com/mandelsoft/goutils/sliceutils"

"github.com/open-component-model/ocm/pkg/utils/pkgutils"
)

type (
MyStruct struct{}

MyList []int
MyArray [3]int
MyMap map[int]int
MyChan chan int
MyFuncType func()
)

func MyFunc(i ...int) (string, error) {
return pkgutils.GetPackageName(sliceutils.Convert[interface{}](i)...)
}

0 comments on commit ba23b9f

Please sign in to comment.