Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: failure to call SavePolicy within the Transaction method(#249)
Browse files Browse the repository at this point in the history
junfengxu authored and Hill1126 committed Nov 13, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent a7e4936 commit 76d3304
Showing 2 changed files with 47 additions and 3 deletions.
20 changes: 17 additions & 3 deletions adapter.go
Original file line number Diff line number Diff line change
@@ -707,18 +707,17 @@ func (a *Adapter) Transaction(e casbin.IEnforcer, fc func(casbin.IEnforcer) erro
a.transactionMu.Lock()
defer a.transactionMu.Unlock()
var err error
oriAdapter := a.db
// reload policy from database to sync with the transaction
defer func() {
e.SetAdapter(&Adapter{db: oriAdapter, transactionMu: a.transactionMu})
e.SetAdapter(a.Copy())
err = e.LoadPolicy()
if err != nil {
panic(err)
}
}()
copyDB := *a.db
tx := copyDB.Begin(opts...)
b := &Adapter{db: tx, transactionMu: a.transactionMu}
b := a.Copy()
// copy enforcer to set the new adapter with transaction tx
copyEnforcer := e
copyEnforcer.SetAdapter(b)
@@ -946,6 +945,21 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [
return oldPolicies, tx.Commit().Error
}

func (a *Adapter) Copy() *Adapter {
oriAdapter := a.db
return &Adapter{
db: oriAdapter,
transactionMu: a.transactionMu,
driverName: a.driverName,
dataSourceName: a.dataSourceName,
databaseName: a.databaseName,
tablePrefix: a.tablePrefix,
tableName: a.tableName,
dbSpecified: a.dbSpecified,
isFiltered: a.isFiltered,
}
}

// Preview Pre-checking to avoid causing partial load success and partial failure deep
func (a *Adapter) Preview(rules *[]CasbinRule, model model.Model) error {
j := 0
30 changes: 30 additions & 0 deletions adapter_test.go
Original file line number Diff line number Diff line change
@@ -766,3 +766,33 @@ func TestTransactionRace(t *testing.T) {
require.True(t, e.HasPolicy("jack", fmt.Sprintf("data%d", i), "write"))
}
}

func TestTransactionWithSavePolicy(t *testing.T) {
a := initAdapter(t, "mysql", "root:@tcp(127.0.0.1:3306)/", "casbin", "casbin_rule")
e, _ := casbin.NewEnforcer("examples/rbac_model.conf", a)
defer func() {
e.ClearPolicy()
err := e.SavePolicy()
if err != nil {
t.Fatalf("save policy err %v", err)
}
}()
err := e.GetAdapter().(*Adapter).Transaction(e, func(e casbin.IEnforcer) error {
_, err := e.AddPolicy("jack", "data1", "write")
if err != nil {
return err
}
_, err = e.AddPolicy("jack", "data2", "write")
if err != nil {
return err
}
err = e.SavePolicy()
if err != nil {
return err
}
return nil
})
if err != nil {
return
}
}

0 comments on commit 76d3304

Please sign in to comment.