首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >golang源码分析:mockey

golang源码分析:mockey

作者头像
golangLeetcode
发布2026-03-18 18:12:00
发布2026-03-18 18:12:00
800
举报

前面介绍了gomonkey和goconvey两个测试工具,字节在它俩基础上封装了一个更直观的工具:github.com/bytedance/mockey,下面结合例子看下是如何使用的。

代码语言:javascript
复制
package main
import (
    "fmt"
    "os"
    "testing"
    . "github.com/bytedance/mockey"
    . "github.com/smartystreets/goconvey/convey"
)
//go:generate go test -v -run ^TestMockXXX$ -gcflags="all=-N -l"
func init() {
    os.Setenv("MOCKEY_CHECK_GCFLAGS", "false")
}
func Foo(in string) string {
    return in
}
type A struct{}
func (a A) Foo(in string) string { return in }
var Bar = 0
func TestMockXXX(t *testing.T) {
    PatchConvey("TestMockXXX", t, func() {
        Mock(Foo).Return("c").Build()   // mock函数
        Mock(A.Foo).Return("c").Build() // mock方法
        MockValue(&Bar).To(1)           // mock变量
        So(Foo("a"), ShouldEqual, "c")        // 断言`Foo`成功mock
        So(new(A).Foo("b"), ShouldEqual, "c") // 断言`A.Foo`成功mock
        So(Bar, ShouldEqual, 1)               // 断言`Bar`成功mock
    })
    // `PatchConvey`外自动释放mock
    fmt.Println(Foo("a"))        // a
    fmt.Println(new(A).Foo("b")) // b
    fmt.Println(Bar)             // 0
}

测试下:

代码语言:javascript
复制
% go test -gcflags="all=-l -N" -v ./test/mockey/exp1/...
=== RUN   TestMockXXX
  TestMockXXX ✔✔✔
3 total assertions
a
b
0
--- PASS: TestMockXXX (0.00s)
PASS
ok      learn/test/mockey/exp1  0.461s

可以看到它将gomonkey的一系列ApplyXXX方法改成了下面格式

代码语言:javascript
复制
 Mock(Foo).Return("c").Build() 
 Mock(A.Foo).Return("c").Build() 
代码语言:javascript
复制
MockValue(&Bar).To(1) 

并且将goconvey的Convey方法改成了PatchConvey方法。同时还支持了When方法,根据条件返回不同的返回值,这里我觉得有点鸡肋,既然Convey就是为了做隔离,不同条件,应该是独立的case,分开写就好了,为啥一定要写在一起呢?

代码语言:javascript
复制
        mock := Mock((*B).Fun).
            When(func(b *B, in string) bool { return len(in) > 10 }). // optionally the class reference can be passed as the first parameter
            Return("MOCKED!").
            Build()

接下来我们看下PatchConvey方法,它内部还是调用了goconvey的Convey方法,只是在调用之前通过defer注册了unPatch方法,这样我们就不用每次显式调用unPatch了。

代码语言:javascript
复制
func PatchConvey(items ...interface{}) {
    for i, item := range items {
        if reflect.TypeOf(item).Kind() == reflect.Func {
            items[i] = reflect.MakeFunc(reflect.TypeOf(item), func(args []reflect.Value) []reflect.Value {
                gMocker = append(gMocker, make(map[uintptr]mockerInstance))
                defer func() {
                    for _, mocker := range gMocker[len(gMocker)-1] {
                        mocker.unPatch()
                    }
                    gMocker = gMocker[:len(gMocker)-1]
                }()
                return tool.ReflectCall(reflect.ValueOf(item), args)
            }).Interface()
        }
    }
    convey.Convey(items...)
}

具体怎么实现的呢?它维护了一个全局的栈

代码语言:javascript
复制
var gMocker = make([]map[uintptr]mockerInstance, 0)

每次调用PatchConvey的时候,它通过反射改写了传入的函数,在函数之前加入一段逻辑,给当前convey添加了一个map来存储unPatch方法,并将它加入到栈结尾,当调用结束后,先执行所有的unPatch,然后栈弹出当前map,类似给函数加了个middleware。

接着看下mock函数

代码语言:javascript
复制
func Mock(target interface{}, opt ...optionFn) *MockBuilder {
    tool.AssertFunc(target)
    option := resolveOpt(opt...)
    builder := &MockBuilder{
        target:  target,
        unsafe:  option.unsafe,
        generic: option.generic,
    }
    builder.resetCondition()
    return builder
}
代码语言:javascript
复制
type MockBuilder struct {
    target          interface{}      // mock target
    proxyCaller     interface{}      // origin function caller hook
    conditions      []*mockCondition // mock conditions
    filterGoroutine FilterGoroutineType
    gId             int64
    unsafe          bool
    generic         bool
}
代码语言:javascript
复制
func (builder *MockBuilder) resetCondition() *MockBuilder {
    builder.conditions = []*mockCondition{builder.newCondition()} // at least 1 condition is needed
    return builder
}

然后是return方法

代码语言:javascript
复制
func (builder *MockBuilder) Return(results ...interface{}) *MockBuilder {
    builder.lastCondition().SetReturn(results...)
    return builder
}

获取栈顶元素

代码语言:javascript
复制
func (builder *MockBuilder) lastCondition() *mockCondition {
    cond := builder.conditions[len(builder.conditions)-1]
    if cond.Complete() {
        cond = builder.newCondition()
        builder.conditions = append(builder.conditions, cond)
    }
    return cond
}

然后把返回值设置进去

代码语言:javascript
复制
func (m *mockCondition) SetReturn(results ...interface{}) {
    tool.Assert(m.hook == nil, "re-set builder hook")
    m.SetReturnForce(results...)
}
代码语言:javascript
复制
func (m *mockCondition) SetReturnForce(results ...interface{}) {
    getResult := func() []interface{} { return results }
    if len(results) == 1 {
        seq, ok := results[0].(SequenceOpt)
        if ok {
            getResult = seq.GetNext
        }
    }
    hookType := m.builder.hookType()
    m.hook = reflect.MakeFunc(hookType, func(_ []reflect.Value) []reflect.Value {
        results := getResult()
        tool.CheckReturnType(m.builder.target, results...)
        valueResults := make([]reflect.Value, 0)
        for i, result := range results {
            rValue := reflect.Zero(hookType.Out(i))
            if result != nil {
                rValue = reflect.ValueOf(result).Convert(hookType.Out(i))
            }
            valueResults = append(valueResults, rValue)
        }
        return valueResults
    }).Interface()
}

其核心函数是这几行

代码语言:javascript
复制
            rValue := reflect.Zero(hookType.Out(i))
            if result != nil {
                rValue = reflect.ValueOf(result).Convert(hookType.Out(i))
            }
            valueResults = append(valueResults, rValue)

然后我们看下build函数

代码语言:javascript
复制
func (builder *MockBuilder) Build() *Mocker {
    mocker := Mocker{target: reflect.ValueOf(builder.target), builder: builder}
    mocker.buildHook()
    mocker.Patch()
    return &mocker
}
代码语言:javascript
复制
func (mocker *Mocker) buildHook() {
    proxySetter := mocker.buildProxy()
    originExec := func(args []reflect.Value) []reflect.Value {
        return tool.ReflectCall(reflect.ValueOf(mocker.proxy).Elem(), args)
    }
    match := []func(args []reflect.Value) bool{}
    exec := []func(args []reflect.Value) []reflect.Value{}
    for i := range mocker.builder.conditions {
        condition := mocker.builder.conditions[i]
        if condition.when == nil {
            // when condition is not set, just go into hook exec
            match = append(match, func(args []reflect.Value) bool { return true })
        } else {
            match = append(match, func(args []reflect.Value) bool {
                return tool.ReflectCall(reflect.ValueOf(condition.when), args)[0].Bool()
            })
        }
        if condition.hook == nil {
            // hook condition is not set, just go into original exec
            exec = append(exec, originExec)
        } else {
            exec = append(exec, func(args []reflect.Value) []reflect.Value {
                mocker.mock()
                return tool.ReflectCall(reflect.ValueOf(condition.hook), args)
            })
        }
    }
    mockerHook := reflect.MakeFunc(mocker.builder.hookType(), func(args []reflect.Value) []reflect.Value {
        proxySetter(args) // 设置origin调用proxy
        mocker.access()
        switch mocker.builder.filterGoroutine {
        case Disable:
            break
        case Include:
            if tool.GetGoroutineID() != mocker.builder.gId {
                return originExec(args)
            }
        case Exclude:
            if tool.GetGoroutineID() == mocker.builder.gId {
                return originExec(args)
            }
        }
        for i, matchFn := range match {
            execFn := exec[i]
            if matchFn(args) {
                return execFn(args)
            }
        }
        return originExec(args)
    })
    mocker.hook = mockerHook
}

完成函数的替换

代码语言:javascript
复制
func (mocker *Mocker) Patch() *Mocker {
    mocker.lock.Lock()
    defer mocker.lock.Unlock()
    if mocker.isPatched {
        return mocker
    }
    mocker.patch = monkey.PatchValue(mocker.target, mocker.hook, reflect.ValueOf(mocker.proxy), mocker.builder.unsafe, mocker.builder.generic)
    mocker.isPatched = true
    addToGlobal(mocker)
    mocker.outerCaller = tool.OuterCaller()
    return mocker
}

Patch的时候,会把当前patch存储到全局

代码语言:javascript
复制
func addToGlobal(mocker mockerInstance) {
    tool.DebugPrintf("%v added\n", mocker.key())
    last, ok := gMocker[len(gMocker)-1][mocker.key()]
    if ok {
        tool.Assert(!ok, "re-mock %v, previous mock at: %v", mocker.name(), last.caller())
    }
    gMocker[len(gMocker)-1][mocker.key()] = mocker
}

对应的UnPatch会从全局map中删除,它在PatchConvey里注册的defer上执行。

代码语言:javascript
复制
func (mocker *Mocker) UnPatch() *Mocker {
    mocker.lock.Lock()
    defer mocker.lock.Unlock()
    if !mocker.isPatched {
        return mocker
    }
    mocker.patch.Unpatch()
    mocker.isPatched = false
    removeFromGlobal(mocker)
    atomic.StoreInt64(&mocker.times, 0)
    atomic.StoreInt64(&mocker.mockTimes, 0)
    return mocker
}
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-06-15,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 golang算法架构leetcode技术php 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档