Skip to content

Commit 0726845

Browse files
authored
Merge pull request #78 from AVOlili/avolili
新增支持指定返回值和method不用传入receiver
2 parents 7052c4a + d6b60c4 commit 0726845

File tree

5 files changed

+283
-2
lines changed

5 files changed

+283
-2
lines changed

patch.go

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ func ApplyMethod(target reflect.Type, methodName string, double interface{}) *Pa
2828
return create().ApplyMethod(target, methodName, double)
2929
}
3030

31+
func ApplyMethodFunc(target reflect.Type, methodName string, doubleFunc interface{}) *Patches {
32+
return create().ApplyMethodFunc(target, methodName, doubleFunc)
33+
}
34+
3135
func ApplyPrivateMethod(target reflect.Type, methodName string, double interface{}) *Patches {
3236
return create().ApplyPrivateMethod(target, methodName, double)
3337
}
@@ -52,6 +56,18 @@ func ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches {
5256
return create().ApplyFuncVarSeq(target, outputs)
5357
}
5458

59+
func ApplyFuncReturn(target interface{}, output ...interface{}) *Patches {
60+
return create().ApplyFuncReturn(target, output...)
61+
}
62+
63+
func ApplyMethodReturn(target interface{}, methodName string, output ...interface{}) *Patches {
64+
return create().ApplyMethodReturn(target, methodName, output...)
65+
}
66+
67+
func ApplyFuncVarReturn(target interface{}, output ...interface{}) *Patches {
68+
return create().ApplyFuncVarReturn(target, output...)
69+
}
70+
5571
func create() *Patches {
5672
return &Patches{originals: make(map[uintptr][]byte), values: make(map[reflect.Value]reflect.Value), valueHolders: make(map[reflect.Value]reflect.Value)}
5773
}
@@ -75,6 +91,15 @@ func (this *Patches) ApplyMethod(target reflect.Type, methodName string, double
7591
return this.ApplyCore(m.Func, d)
7692
}
7793

94+
func (this *Patches) ApplyMethodFunc(target reflect.Type, methodName string, doubleFunc interface{}) *Patches {
95+
m, ok := target.MethodByName(methodName)
96+
if !ok {
97+
panic("retrieve method by name failed")
98+
}
99+
d := funcToMethod(m.Type, doubleFunc)
100+
return this.ApplyCore(m.Func, d)
101+
}
102+
78103
func (this *Patches) ApplyPrivateMethod(target reflect.Type, methodName string, double interface{}) *Patches {
79104
m, ok := creflect.MethodByName(target, methodName)
80105
if !ok {
@@ -136,6 +161,40 @@ func (this *Patches) ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *
136161
return this.ApplyGlobalVar(target, double)
137162
}
138163

164+
func (this *Patches) ApplyFuncReturn(target interface{}, returns ...interface{}) *Patches {
165+
funcType := reflect.TypeOf(target)
166+
t := reflect.ValueOf(target)
167+
outputs := []OutputCell{{Values: returns, Times: -1}}
168+
d := getDoubleFunc(funcType, outputs)
169+
return this.ApplyCore(t, d)
170+
}
171+
172+
func (this *Patches) ApplyMethodReturn(target interface{}, methodName string, returns ...interface{}) *Patches {
173+
m, ok := reflect.TypeOf(target).MethodByName(methodName)
174+
if !ok {
175+
panic("retrieve method by name failed")
176+
}
177+
178+
outputs := []OutputCell{{Values: returns, Times: -1}}
179+
d := getDoubleFunc(m.Type, outputs)
180+
return this.ApplyCore(m.Func, d)
181+
}
182+
183+
func (this *Patches) ApplyFuncVarReturn(target interface{}, returns ...interface{}) *Patches {
184+
t := reflect.ValueOf(target)
185+
if t.Type().Kind() != reflect.Ptr {
186+
panic("target is not a pointer")
187+
}
188+
if t.Elem().Kind() != reflect.Func {
189+
panic("target is not a func")
190+
}
191+
192+
funcType := reflect.TypeOf(target).Elem()
193+
outputs := []OutputCell{{Values: returns, Times: -1}}
194+
double := getDoubleFunc(funcType, outputs).Interface()
195+
return this.ApplyGlobalVar(target, double)
196+
}
197+
139198
func (this *Patches) Reset() {
140199
for target, bytes := range this.originals {
141200
modifyBinary(target, bytes)
@@ -203,8 +262,14 @@ func getDoubleFunc(funcType reflect.Type, outputs []OutputCell) reflect.Value {
203262
funcType.NumOut(), len(outputs[0].Values)))
204263
}
205264

265+
needReturn := false
206266
slice := make([]Params, 0)
207267
for _, output := range outputs {
268+
if output.Times == -1 {
269+
needReturn = true
270+
slice = []Params{output.Values}
271+
break
272+
}
208273
t := 0
209274
if output.Times <= 1 {
210275
t = 1
@@ -217,9 +282,12 @@ func getDoubleFunc(funcType reflect.Type, outputs []OutputCell) reflect.Value {
217282
}
218283

219284
i := 0
220-
len := len(slice)
285+
lenOutputs := len(slice)
221286
return reflect.MakeFunc(funcType, func(_ []reflect.Value) []reflect.Value {
222-
if i < len {
287+
if needReturn {
288+
return GetResultValues(funcType, slice[0]...)
289+
}
290+
if i < lenOutputs {
223291
i++
224292
return GetResultValues(funcType, slice[i-1]...)
225293
}
@@ -259,3 +327,14 @@ func entryAddress(p uintptr, l int) []byte {
259327
func pageStart(ptr uintptr) uintptr {
260328
return ptr & ^(uintptr(syscall.Getpagesize() - 1))
261329
}
330+
331+
func funcToMethod(funcType reflect.Type, doubleFunc interface{}) reflect.Value {
332+
rf := reflect.TypeOf(doubleFunc)
333+
if rf.Kind() != reflect.Func {
334+
panic("doubleFunc is not a func")
335+
}
336+
vf := reflect.ValueOf(doubleFunc)
337+
return reflect.MakeFunc(funcType, func(in []reflect.Value) []reflect.Value {
338+
return vf.Call(in[1:])
339+
})
340+
}

test/apply_func_return_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package test
2+
3+
import (
4+
"testing"
5+
6+
. "github.com/agiledragon/gomonkey/v2"
7+
"github.com/agiledragon/gomonkey/v2/test/fake"
8+
. "github.com/smartystreets/goconvey/convey"
9+
)
10+
11+
/*
12+
compare with apply_func_seq_test.go
13+
*/
14+
func TestApplyFuncReturn(t *testing.T) {
15+
Convey("TestApplyFuncReturn", t, func() {
16+
17+
Convey("declares the values to be returned", func() {
18+
info1 := "hello cpp"
19+
20+
patches := ApplyFuncReturn(fake.ReadLeaf, info1, nil)
21+
defer patches.Reset()
22+
23+
for i := 0; i < 10; i++ {
24+
output, err := fake.ReadLeaf("")
25+
So(err, ShouldEqual, nil)
26+
So(output, ShouldEqual, info1)
27+
}
28+
29+
patches.Reset() // if not reset will occur:patch has been existed
30+
info2 := "hello golang"
31+
patches.ApplyFuncReturn(fake.ReadLeaf, info2, nil)
32+
for i := 0; i < 10; i++ {
33+
output, err := fake.ReadLeaf("")
34+
So(err, ShouldEqual, nil)
35+
So(output, ShouldEqual, info2)
36+
}
37+
})
38+
})
39+
}

test/apply_func_var_return_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package test
2+
3+
import (
4+
"testing"
5+
6+
. "github.com/agiledragon/gomonkey/v2"
7+
"github.com/agiledragon/gomonkey/v2/test/fake"
8+
. "github.com/smartystreets/goconvey/convey"
9+
)
10+
11+
/*
12+
compare with apply_func_var_seq_test.go
13+
*/
14+
func TestApplyFuncVarReturn(t *testing.T) {
15+
Convey("TestApplyFuncVarReturn", t, func() {
16+
17+
Convey("declares the values to be returned", func() {
18+
info1 := "hello cpp"
19+
20+
patches := ApplyFuncVarReturn(&fake.Marshal, []byte(info1), nil)
21+
defer patches.Reset()
22+
for i := 0; i < 10; i++ {
23+
bytes, err := fake.Marshal("")
24+
So(err, ShouldEqual, nil)
25+
So(string(bytes), ShouldEqual, info1)
26+
}
27+
28+
info2 := "hello golang"
29+
patches.ApplyFuncVarReturn(&fake.Marshal, []byte(info2), nil)
30+
for i := 0; i < 10; i++ {
31+
bytes, err := fake.Marshal("")
32+
So(err, ShouldEqual, nil)
33+
So(string(bytes), ShouldEqual, info2)
34+
}
35+
})
36+
37+
})
38+
}

test/apply_method_func_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package test
2+
3+
import (
4+
"reflect"
5+
"testing"
6+
7+
. "github.com/agiledragon/gomonkey/v2"
8+
"github.com/agiledragon/gomonkey/v2/test/fake"
9+
. "github.com/smartystreets/goconvey/convey"
10+
)
11+
12+
/*
13+
compare with apply_method_test.go, no need pass receiver
14+
*/
15+
16+
func TestApplyMethodFunc(t *testing.T) {
17+
slice := fake.NewSlice()
18+
var s *fake.Slice
19+
Convey("TestApplyMethodFunc", t, func() {
20+
Convey("for succ", func() {
21+
err := slice.Add(1)
22+
So(err, ShouldEqual, nil)
23+
patches := ApplyMethodFunc(reflect.TypeOf(s), "Add", func(_ int) error {
24+
return nil
25+
})
26+
defer patches.Reset()
27+
err = slice.Add(1)
28+
So(err, ShouldEqual, nil)
29+
err = slice.Remove(1)
30+
So(err, ShouldEqual, nil)
31+
So(len(slice), ShouldEqual, 0)
32+
})
33+
34+
Convey("for already exist", func() {
35+
err := slice.Add(2)
36+
So(err, ShouldEqual, nil)
37+
patches := ApplyMethodFunc(reflect.TypeOf(s), "Add", func(_ int) error {
38+
return fake.ErrElemExsit
39+
})
40+
defer patches.Reset()
41+
err = slice.Add(1)
42+
So(err, ShouldEqual, fake.ErrElemExsit)
43+
err = slice.Remove(2)
44+
So(err, ShouldEqual, nil)
45+
So(len(slice), ShouldEqual, 0)
46+
})
47+
48+
Convey("two methods", func() {
49+
err := slice.Add(3)
50+
So(err, ShouldEqual, nil)
51+
defer slice.Remove(3)
52+
patches := ApplyMethodFunc(reflect.TypeOf(s), "Add", func(_ int) error {
53+
return fake.ErrElemExsit
54+
})
55+
defer patches.Reset()
56+
patches.ApplyMethodFunc(reflect.TypeOf(s), "Remove", func(_ int) error {
57+
return fake.ErrElemNotExsit
58+
})
59+
err = slice.Add(2)
60+
So(err, ShouldEqual, fake.ErrElemExsit)
61+
err = slice.Remove(1)
62+
So(err, ShouldEqual, fake.ErrElemNotExsit)
63+
So(len(slice), ShouldEqual, 1)
64+
So(slice[0], ShouldEqual, 3)
65+
})
66+
67+
Convey("one func and one method", func() {
68+
err := slice.Add(4)
69+
So(err, ShouldEqual, nil)
70+
defer slice.Remove(4)
71+
patches := ApplyFunc(fake.Exec, func(_ string, _ ...string) (string, error) {
72+
return outputExpect, nil
73+
})
74+
defer patches.Reset()
75+
patches.ApplyMethodFunc(reflect.TypeOf(s), "Remove", func(_ int) error {
76+
return fake.ErrElemNotExsit
77+
})
78+
output, err := fake.Exec("", "")
79+
So(err, ShouldEqual, nil)
80+
So(output, ShouldEqual, outputExpect)
81+
err = slice.Remove(1)
82+
So(err, ShouldEqual, fake.ErrElemNotExsit)
83+
So(len(slice), ShouldEqual, 1)
84+
So(slice[0], ShouldEqual, 4)
85+
})
86+
})
87+
}

test/apply_method_return_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package test
2+
3+
import (
4+
"testing"
5+
6+
. "github.com/agiledragon/gomonkey/v2"
7+
"github.com/agiledragon/gomonkey/v2/test/fake"
8+
. "github.com/smartystreets/goconvey/convey"
9+
)
10+
11+
/*
12+
compare with apply_method_seq_test.go
13+
*/
14+
15+
func TestApplyMethodReturn(t *testing.T) {
16+
e := &fake.Etcd{}
17+
Convey("TestApplyMethodReturn", t, func() {
18+
Convey("declares the values to be returned", func() {
19+
info1 := "hello cpp"
20+
patches := ApplyMethodReturn(e, "Retrieve", info1, nil)
21+
defer patches.Reset()
22+
for i := 0; i < 10; i++ {
23+
output1, err1 := e.Retrieve("")
24+
So(err1, ShouldEqual, nil)
25+
So(output1, ShouldEqual, info1)
26+
}
27+
28+
patches.Reset() // if not reset will occur:patch has been existed
29+
info2 := "hello golang"
30+
patches.ApplyMethodReturn(e, "Retrieve", info2, nil)
31+
for i := 0; i < 10; i++ {
32+
output2, err2 := e.Retrieve("")
33+
So(err2, ShouldEqual, nil)
34+
So(output2, ShouldEqual, info2)
35+
}
36+
})
37+
})
38+
}

0 commit comments

Comments
 (0)