unit-testing之如何使用 gomock (或类似的)来模拟/验证对数据库的调用

jiqing9006 阅读:222 2025-06-02 22:19:02 评论:0

转到这里,使用 gorm到或/映射到数据库(PSQL)。
我有以下代码:

package dbstuff 
 
import ( 
    "errors" 
 
  "github.com/google/uuid" 
  "github.com/jinzhu/gorm" 
    _ "github.com/jinzhu/gorm/dialects/postgres" 
) 
 
type OrderPersister struct { 
        db *gorm.DB 
} 
 
func (p *OrderPersister) GetOrder(id uuid.UUID) (*Order, error) { 
        ret := &Order{} 
 
        err := p.db.Table("orders").Where("order_id = ?", id).Scan(ret).Error 
        return ret, err 
} 
我正在尝试为它编写一个单元测试,如下所示:
package dbstuff 
 
import ( 
    "testing" 
  "errors" 
 
  "github.com/stretchr/testify/assert" 
) 
 
func TestErrInternalServerError(t *testing.T) { 
 
  // given 
  id := uuid.New() 
  op := OrderPersister{} 
 
  // when 
  order, err := op.GetOrder(id) 
 
  // then 
  assert.NotNil(t, order) 
  assert.NotNil(t, err) 
 
} 
当我运行它时,我得到无效的内存地址或 nil 指针取消引用错误,因为我没有实例化设置 *gorm.DB在我的 OrderPersister实例。有没有一种简单的方法来模拟/ stub ,以便我的测试确认我们尝试查询 orders表并返回或/映射的结果?

请您参考如下方法:

我将使用 testify包为您的代码编写单元测试。而不是使用具体类型 *gorm.DB ,声明 OrderPersister 的 DB 接口(interface)结构。由于我们不能在 Go 中模拟具体类型及其方法。我们需要创建一个抽象层 - interface .63622995/db/db.go :

package db 
 
type OrmDBWithError struct { 
    OrmDB 
    Error error 
} 
 
type OrmDB interface { 
    Table(name string) OrmDB 
    Where(query interface{}, args ...interface{}) OrmDB 
    Scan(dest interface{}) *OrmDBWithError 
} 
63622995/main.go :
package main 
 
import ( 
    "github.com/google/uuid" 
    _ "github.com/jinzhu/gorm/dialects/postgres" 
    "github.com/mrdulin/golang/src/stackoverflow/63622995/db" 
) 
 
type Order struct { 
    order_id string 
} 
 
type OrderPersister struct { 
    DB db.OrmDB 
    //DB *gorm.DB 
} 
 
func (p *OrderPersister) GetOrder(id uuid.UUID) (*Order, error) { 
    ret := &Order{} 
 
    err := p.DB.Table("orders").Where("order_id = ?", id).Scan(ret).Error 
    return ret, err 
} 
为实现 OrmDB 的 db 创建了模拟对象界面。然后,您可以创建此模拟 DB 对象并将其传递给 OrderPersister结构。 63622995/mocks/db.go :
package mocks 
 
import ( 
    "github.com/mrdulin/golang/src/stackoverflow/63622995/db" 
    "github.com/stretchr/testify/mock" 
) 
 
type MockedOrmDB struct { 
    mock.Mock 
} 
 
func (s *MockedOrmDB) Table(name string) db.OrmDB { 
    args := s.Called(name) 
    return args.Get(0).(db.OrmDB) 
} 
 
func (s *MockedOrmDB) Where(query interface{}, args ...interface{}) db.OrmDB { 
    arguments := s.Called(query, args) 
    return arguments.Get(0).(db.OrmDB) 
} 
 
func (s *MockedOrmDB) Scan(dest interface{}) *db.OrmDBWithError { 
    args := s.Called(dest) 
    return args.Get(0).(*db.OrmDBWithError) 
} 
63622995/main_test.go :
package main 
 
import ( 
    "testing" 
 
    "github.com/google/uuid" 
    "github.com/mrdulin/golang/src/stackoverflow/63622995/db" 
    "github.com/mrdulin/golang/src/stackoverflow/63622995/mocks" 
    "github.com/stretchr/testify/assert" 
    "github.com/stretchr/testify/mock" 
) 
 
func TestOrderPersister_GetOrder(t *testing.T) { 
    assert := assert.New(t) 
    t.Run("should get order", func(t *testing.T) { 
        testDb := new(mocks.MockedOrmDB) 
        id := uuid.New() 
        testDb. 
            On("Table", "orders"). 
            Return(testDb). 
            On("Where", "order_id = ?", mock.Anything). 
            Return(testDb). 
            On("Scan", mock.Anything).Run(func(args mock.Arguments) { 
            ret := args.Get(0).(*Order) 
            ret.order_id = "123" 
        }). 
            Return(&db.OrmDBWithError{Error: nil}) 
        op := OrderPersister{DB: testDb} 
        got, err := op.GetOrder(id) 
        testDb.AssertExpectations(t) 
        assert.Nil(err) 
        assert.Equal(Order{order_id: "123"}, *got) 
    }) 
 
    t.Run("should return error", func(t *testing.T) { 
        testDb := new(mocks.MockedOrmDB) 
        id := uuid.New() 
        testDb. 
            On("Table", "orders"). 
            Return(testDb). 
            On("Where", "order_id = ?", mock.Anything). 
            Return(testDb). 
            On("Scan", mock.Anything). 
            Return(&db.OrmDBWithError{Error: errors.New("network")}) 
        op := OrderPersister{DB: testDb} 
        got, err := op.GetOrder(id) 
        testDb.AssertExpectations(t) 
        assert.Equal(Order{}, *got) 
        assert.Equal(err.Error(), "network") 
    }) 
} 
单元测试结果:
=== RUN   TestOrderPersister_GetOrder 
=== RUN   TestOrderPersister_GetOrder/should_get_order 
    TestOrderPersister_GetOrder/should_get_order: main_test.go:32: PASS:    Table(string) 
    TestOrderPersister_GetOrder/should_get_order: main_test.go:32: PASS:    Where(string,string) 
    TestOrderPersister_GetOrder/should_get_order: main_test.go:32: PASS:    Scan(string) 
=== RUN   TestOrderPersister_GetOrder/should_return_error 
    TestOrderPersister_GetOrder/should_return_error: main_test.go:49: PASS: Table(string) 
    TestOrderPersister_GetOrder/should_return_error: main_test.go:49: PASS: Where(string,string) 
    TestOrderPersister_GetOrder/should_return_error: main_test.go:49: PASS: Scan(string) 
--- PASS: TestOrderPersister_GetOrder (0.00s) 
    --- PASS: TestOrderPersister_GetOrder/should_get_order (0.00s) 
    --- PASS: TestOrderPersister_GetOrder/should_return_error (0.00s) 
PASS 
 
Process finished with exit code 0 
覆盖报告:


标签:数据库
声明

1.本站遵循行业规范,任何转载的稿件都会明确标注作者和来源;2.本站的原创文章,请转载时务必注明文章作者和来源,不尊重原创的行为我们将追究责任;3.作者投稿可能会经我们编辑修改或补充。

关注我们

一个IT知识分享的公众号