
前面介绍了gomonkey和goconvey两个测试工具,字节跳动在它们的基础上封装了一个更直观的工具:github.com/bytedance/mockey,下面结合例子看看它是如何使用的。
package main
import (
"fmt"
"os"
"testing"
. "github.com/bytedance/mockey"
. "github.com/smartystreets/goconvey/convey"
)
// Run the test with optimizations and inlining disabled so mockey can patch
// Foo and A.Foo. The whole -gcflags argument is quoted because go generate
// only honours quotes that start a word: `-gcflags="all=-N -l"` would be
// passed on as the two broken words `-gcflags="all=-N` and `-l"`.
//go:generate go test -v -run ^TestMockXXX$ "-gcflags=all=-N -l"
func init() {
	// NOTE(review): presumably tells mockey to skip its own check that the
	// binary was built with -gcflags="all=-N -l" — confirm against mockey docs.
	if err := os.Setenv("MOCKEY_CHECK_GCFLAGS", "false"); err != nil {
		panic(err)
	}
}
// Foo is the plain-function mock target; it returns its argument unchanged.
func Foo(in string) string {
	out := in
	return out
}

// A is the receiver type whose value method Foo is mocked in TestMockXXX.
type A struct{}

// Foo returns its argument unchanged; it is the method mock target.
func (A) Foo(in string) string {
	return in
}

// Bar is the package-level variable mocked with MockValue; it starts at 0.
var Bar int
func TestMockXXX(t *testing.T) {
PatchConvey("TestMockXXX", t, func() {
Mock(Foo).Return("c").Build() // mock函数
Mock(A.Foo).Return("c").Build() // mock方法
MockValue(&Bar).To(1) // mock变量
So(Foo("a"), ShouldEqual, "c") // 断言`Foo`成功mock
So(new(A).Foo("b"), ShouldEqual, "c") // 断言`A.Foo`成功mock
So(Bar, ShouldEqual, 1) // 断言`Bar`成功mock
})
// `PatchConvey`外自动释放mock
fmt.Println(Foo("a")) // a
fmt.Println(new(A).Foo("b")) // b
fmt.Println(Bar) // 0
}测试下:
% go test -gcflags="all=-l -N" -v ./test/mockey/exp1/...
=== RUN TestMockXXX
TestMockXXX ✔✔✔
3 total assertions
a
b
0
--- PASS: TestMockXXX (0.00s)
PASS
ok learn/test/mockey/exp1 0.461s
可以看到,它将gomonkey的一系列ApplyXXX方法改成了下面的格式:
Mock(Foo).Return("c").Build()   // mock a function
Mock(A.Foo).Return("c").Build() // mock a method
MockValue(&Bar).To(1)           // mock a variable
并且将goconvey的Convey方法改成了PatchConvey方法。同时还支持了When方法,根据条件返回不同的返回值,这里我觉得有点鸡肋,既然Convey就是为了做隔离,不同条件,应该是独立的case,分开写就好了,为啥一定要写在一起呢?
// When attaches a predicate to the condition being built. The Return values
// are used only for calls where the predicate reports true; other calls fall
// through to later conditions or to the original (*B).Fun.
mock := Mock((*B).Fun).
When(func(b *B, in string) bool { return len(in) > 10 }). // optionally the class reference can be passed as the first parameter
Return("MOCKED!").
Build()接下来我们看下PatchConvey方法,它内部还是调用了goconvey的Convey方法,只是在调用之前通过defer注册了unPatch方法,这样我们就不用每次显式调用unPatch了。
func PatchConvey(items ...interface{}) {
for i, item := range items {
if reflect.TypeOf(item).Kind() == reflect.Func {
items[i] = reflect.MakeFunc(reflect.TypeOf(item), func(args []reflect.Value) []reflect.Value {
gMocker = append(gMocker, make(map[uintptr]mockerInstance))
defer func() {
for _, mocker := range gMocker[len(gMocker)-1] {
mocker.unPatch()
}
gMocker = gMocker[:len(gMocker)-1]
}()
return tool.ReflectCall(reflect.ValueOf(item), args)
}).Interface()
}
}
convey.Convey(items...)
}具体怎么实现的呢?它维护了一个全局的栈
// gMocker is a stack of mock scopes: PatchConvey pushes one map per wrapped
// function call and pops it afterwards, and addToGlobal registers each mocker
// in the top map keyed by mocker.key(). As declared here it starts empty.
// NOTE(review): addToGlobal indexes gMocker[len(gMocker)-1], so Mock outside
// any PatchConvey would index out of range unless a scope is pushed elsewhere — confirm.
var gMocker = make([]map[uintptr]mockerInstance, 0)每次调用PatchConvey的时候,它通过反射改写了传入的函数,在函数之前加入一段逻辑,给当前convey添加了一个map来存储unPatch方法,并将它加入到栈结尾,当调用结束后,先执行所有的unPatch,然后栈弹出当前map,类似给函数加了个middleware。
接着看下Mock函数
func Mock(target interface{}, opt ...optionFn) *MockBuilder {
tool.AssertFunc(target)
option := resolveOpt(opt...)
builder := &MockBuilder{
target: target,
unsafe: option.unsafe,
generic: option.generic,
}
builder.resetCondition()
return builder
}type MockBuilder struct {
target interface{} // mock target
proxyCaller interface{} // origin function caller hook
conditions []*mockCondition // mock conditions
filterGoroutine FilterGoroutineType
gId int64
unsafe bool
generic bool
}func (builder *MockBuilder) resetCondition() *MockBuilder {
builder.conditions = []*mockCondition{builder.newCondition()} // at least 1 condition is needed
return builder
}然后是return方法
func (builder *MockBuilder) Return(results ...interface{}) *MockBuilder {
builder.lastCondition().SetReturn(results...)
return builder
}获取栈顶元素
func (builder *MockBuilder) lastCondition() *mockCondition {
cond := builder.conditions[len(builder.conditions)-1]
if cond.Complete() {
cond = builder.newCondition()
builder.conditions = append(builder.conditions, cond)
}
return cond
}然后把返回值设置进去
func (m *mockCondition) SetReturn(results ...interface{}) {
tool.Assert(m.hook == nil, "re-set builder hook")
m.SetReturnForce(results...)
}func (m *mockCondition) SetReturnForce(results ...interface{}) {
getResult := func() []interface{} { return results }
if len(results) == 1 {
seq, ok := results[0].(SequenceOpt)
if ok {
getResult = seq.GetNext
}
}
hookType := m.builder.hookType()
m.hook = reflect.MakeFunc(hookType, func(_ []reflect.Value) []reflect.Value {
results := getResult()
tool.CheckReturnType(m.builder.target, results...)
valueResults := make([]reflect.Value, 0)
for i, result := range results {
rValue := reflect.Zero(hookType.Out(i))
if result != nil {
rValue = reflect.ValueOf(result).Convert(hookType.Out(i))
}
valueResults = append(valueResults, rValue)
}
return valueResults
}).Interface()
}其核心函数是这几行
// Excerpt from SetReturnForce's hook: the i-th return value is the zero value
// of the hook's i-th result type when result is nil, otherwise result converted
// to that type with reflect.Value.Convert.
rValue := reflect.Zero(hookType.Out(i))
if result != nil {
rValue = reflect.ValueOf(result).Convert(hookType.Out(i))
}
valueResults = append(valueResults, rValue)然后我们看下build函数
// Build turns the builder into an active mock: it assembles the dispatching
// hook (buildHook), installs it with Patch, and returns the resulting Mocker.
func (builder *MockBuilder) Build() *Mocker {
mocker := Mocker{target: reflect.ValueOf(builder.target), builder: builder}
mocker.buildHook()
mocker.Patch()
return &mocker
}func (mocker *Mocker) buildHook() {
// buildHook (declared on the line above) assembles mocker.hook, the function
// that replaces the target once patched; it does not patch anything itself.
proxySetter := mocker.buildProxy()
// originExec calls the original implementation through mocker.proxy,
// dereferenced with .Elem().
originExec := func(args []reflect.Value) []reflect.Value {
return tool.ReflectCall(reflect.ValueOf(mocker.proxy).Elem(), args)
}
// match[i] and exec[i] are the predicate and the action for conditions[i].
match := []func(args []reflect.Value) bool{}
exec := []func(args []reflect.Value) []reflect.Value{}
for i := range mocker.builder.conditions {
// Declared inside the loop body so each closure below captures its own condition.
condition := mocker.builder.conditions[i]
if condition.when == nil {
// when condition is not set, just go into hook exec
match = append(match, func(args []reflect.Value) bool { return true })
} else {
match = append(match, func(args []reflect.Value) bool {
return tool.ReflectCall(reflect.ValueOf(condition.when), args)[0].Bool()
})
}
if condition.hook == nil {
// hook condition is not set, just go into original exec
exec = append(exec, originExec)
} else {
exec = append(exec, func(args []reflect.Value) []reflect.Value {
mocker.mock()
return tool.ReflectCall(reflect.ValueOf(condition.hook), args)
})
}
}
// The installed hook: pass the arguments to proxySetter, call
// mocker.access(), apply the goroutine filter, then run the action of the
// first condition whose predicate matches. Calls that match no condition
// go to the original implementation.
mockerHook := reflect.MakeFunc(mocker.builder.hookType(), func(args []reflect.Value) []reflect.Value {
proxySetter(args) // set up the proxy used to call the original function
mocker.access()
switch mocker.builder.filterGoroutine {
case Disable:
break
case Include:
// Only goroutine gId is mocked; every other goroutine calls the original.
if tool.GetGoroutineID() != mocker.builder.gId {
return originExec(args)
}
case Exclude:
// Goroutine gId calls the original; every other goroutine is mocked.
if tool.GetGoroutineID() == mocker.builder.gId {
return originExec(args)
}
}
for i, matchFn := range match {
execFn := exec[i]
if matchFn(args) {
return execFn(args)
}
}
return originExec(args)
})
mocker.hook = mockerHook
}完成函数的替换
func (mocker *Mocker) Patch() *Mocker {
mocker.lock.Lock()
defer mocker.lock.Unlock()
if mocker.isPatched {
return mocker
}
mocker.patch = monkey.PatchValue(mocker.target, mocker.hook, reflect.ValueOf(mocker.proxy), mocker.builder.unsafe, mocker.builder.generic)
mocker.isPatched = true
addToGlobal(mocker)
mocker.outerCaller = tool.OuterCaller()
return mocker
}Patch的时候,会把当前patch存储到全局
func addToGlobal(mocker mockerInstance) {
tool.DebugPrintf("%v added\n", mocker.key())
last, ok := gMocker[len(gMocker)-1][mocker.key()]
if ok {
tool.Assert(!ok, "re-mock %v, previous mock at: %v", mocker.name(), last.caller())
}
gMocker[len(gMocker)-1][mocker.key()] = mocker
}对应的UnPatch会从全局map中删除,它在PatchConvey里注册的defer上执行。
func (mocker *Mocker) UnPatch() *Mocker {
mocker.lock.Lock()
defer mocker.lock.Unlock()
if !mocker.isPatched {
return mocker
}
mocker.patch.Unpatch()
mocker.isPatched = false
removeFromGlobal(mocker)
atomic.StoreInt64(&mocker.times, 0)
atomic.StoreInt64(&mocker.mockTimes, 0)
return mocker
}本文分享自 golang算法架构leetcode技术php 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!