
Dependency Injection in Go

李海彬
Published 2018-10-08 14:49:05
Included in the column: Golang语言社区

Original author: Tin Rabzelj

I have written a small utility package to handle dependency injection in Go (it's in tinrab/kit, among other things). The goal was simplicity and for it to fit well in my current side project.

Some potentially useful features (like optional dependencies) are not yet implemented.

What follows is a possible use case when writing tests for services.

Declaring services

The first step is to declare an interface, and at least one struct that implements it, for every dependency.

Database

The SQLDatabase interface represents a database connection using Go's database/sql package. The actual connection is handled within the unexported struct mySQLDatabase, which uses the mysql driver to connect to a MySQL server.

package main

import (
	"database/sql"

	_ "github.com/go-sql-driver/mysql"
	"github.com/tinrab/kit"
)

type SQLDatabase interface {
	kit.Dependency
	SQL() *sql.DB
}

type mySQLDatabase struct {
	address string
	conn    *sql.DB
}

func NewMySQLDatabase(address string) SQLDatabase {
	return &mySQLDatabase{
		address: address,
	}
}

func (db *mySQLDatabase) SQL() *sql.DB {
	return db.conn
}

The Open and Close functions are required by the kit.Dependency interface.

func (db *mySQLDatabase) Open() error {
	// sql.Open may only validate its arguments; it does not necessarily connect.
	conn, err := sql.Open("mysql", db.address)
	if err != nil {
		return err
	}
	db.conn = conn
	return nil
}

func (db *mySQLDatabase) Close() {
	// Guard against Close being called when Open failed or never ran.
	if db.conn != nil {
		db.conn.Close()
	}
}
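
The article never shows the definition of kit.Dependency. Judging only from the two methods implemented above, it presumably looks roughly like the Dependency interface below. This is a sketch under that assumption, not the library's actual source, and memoryStore is a hypothetical type added only to show a second implementation that needs no external server.

```go
package main

import "errors"

// Dependency is an assumed shape of kit.Dependency, inferred from the
// Open and Close methods used in this article; the real interface may differ.
type Dependency interface {
	Open() error
	Close()
}

// memoryStore is a hypothetical dependency that needs no external server.
type memoryStore struct {
	data map[string]string
}

// Compile-time check that *memoryStore satisfies Dependency.
var _ Dependency = (*memoryStore)(nil)

// Open allocates the backing map; opening twice is reported as an error.
func (s *memoryStore) Open() error {
	if s.data != nil {
		return errors.New("memoryStore: already open")
	}
	s.data = make(map[string]string)
	return nil
}

// Close releases the map so that the store can be opened again.
func (s *memoryStore) Close() {
	s.data = nil
}
```

The blank-identifier assignment makes the compiler reject the file if a method is missing or has the wrong signature, which is a cheap way to keep implementations and the interface in sync.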

User repository

The user repository manages the users of this application.

Declare a struct to hold a user's data.

type User struct {
	ID   uint64
	Name string
}

Declare the UserRepository interface and the mySQLUserRepository struct.

package main

import "github.com/tinrab/kit"

type UserRepository interface {
	kit.Dependency
	GetUserByID(id uint64) (*User, error)
}

type mySQLUserRepository struct {
	Database SQLDatabase `inject:"database"`
}

func NewMySQLUserRepository() UserRepository {
	return &mySQLUserRepository{}
}

func (r *mySQLUserRepository) Open() error {
	return nil
}

func (r *mySQLUserRepository) Close() {
}

Note the inject tag on the Database field. Its value, database, means that the dependency registered under the name database is injected into this field. The SQLDatabase value becomes available after the Open function has been called.

Continue by implementing the rest of the interface.

func (r *mySQLUserRepository) GetUserByID(id uint64) (*User, error) {
	user := &User{}
	// Scan expects the query to return exactly two columns, in this order.
	err := r.Database.SQL().QueryRow("SELECT * FROM users WHERE id = ?", id).
		Scan(&user.ID, &user.Name)
	if err != nil {
		return nil, err
	}
	return user, nil
}
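
To make the inject tag less magical, here is a minimal sketch of how such a tag can be read with the standard reflect package. The repo struct and the injectNames helper are hypothetical names invented for this illustration; how tinrab/kit reads the tag internally is not shown in this article.

```go
package main

import "reflect"

// repo is a hypothetical struct that mirrors the tagged field above.
type repo struct {
	Database interface{} `inject:"database"`
	cache    map[string]string // no tag, so it is skipped
}

// injectNames maps field names to dependency names for every field of the
// struct pointed to by target that carries an `inject` tag.
func injectNames(target interface{}) map[string]string {
	names := make(map[string]string)
	v := reflect.ValueOf(target)
	if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
		return names // only pointers to structs can be inspected
	}
	t := v.Elem().Type()
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		if name, ok := field.Tag.Lookup("inject"); ok {
			names[field.Name] = name
		}
	}
	return names
}
```

Calling injectNames(&repo{}) yields a map with the single entry Database: database, which is the information a container needs to decide which registered dependency belongs in which field.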

Post repository

The post repository is very similar to the user repository.

type Post struct {
	ID     uint64
	UserID uint64
	Title  string
	Body   string
}

Declare the interface and a struct.

package main

import "github.com/tinrab/kit"

type PostRepository interface {
	kit.Dependency
	GetPostsByUser(userID uint64) ([]Post, error)
}

type mySQLPostRepository struct {
	Database SQLDatabase `inject:"database"`
}

func NewMySQLPostRepository() PostRepository {
	return &mySQLPostRepository{}
}

func (r *mySQLPostRepository) Open() error {
	return nil
}

func (r *mySQLPostRepository) Close() {
}

The GetPostsByUser function queries posts by the user's ID.

func (r *mySQLPostRepository) GetPostsByUser(userID uint64) ([]Post, error) {
	rows, err := r.Database.SQL().Query("SELECT * FROM posts WHERE user_id = ?", userID)
	if err != nil {
		return nil, err
	}
	// Release the underlying connection once iteration is done.
	defer rows.Close()

	var posts []Post
	for rows.Next() {
		var post Post
		if err := rows.Scan(&post.ID, &post.UserID, &post.Title, &post.Body); err != nil {
			return nil, err
		}
		posts = append(posts, post)
	}
	// Report any error that ended the iteration early.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return posts, nil
}

Blog service

The blog service uses the previously implemented repositories to provide an API for reading user profiles.

package main

import "github.com/tinrab/kit"

type UserProfile struct {
	User  User
	Posts []Post
}

type BlogService interface {
	kit.Dependency
	GetUserProfile(userID uint64) (*UserProfile, error)
}

type blogServiceImpl struct {
	UserRepository UserRepository `inject:"user.repository"`
	PostRepository PostRepository `inject:"post.repository"`
}

func NewBlogService() BlogService {
	return &blogServiceImpl{}
}

func (*blogServiceImpl) Open() error {
	return nil
}

func (*blogServiceImpl) Close() {
}

If the dependencies were resolved properly, both fields contain non-nil instances.

func (s *blogServiceImpl) GetUserProfile(userID uint64) (*UserProfile, error) {
	user, err := s.UserRepository.GetUserByID(userID)
	if err != nil {
		return nil, err
	}
	posts, err := s.PostRepository.GetPostsByUser(userID)
	if err != nil {
		return nil, err
	}
	return &UserProfile{
		User:  *user,
		Posts: posts,
	}, nil
}

Resolving dependencies

To inject all dependencies, first provide them by name, then call the Resolve function.

di := kit.NewDependencyInjection()

di.Provide("database", NewMySQLDatabase("root:123456@tcp(127.0.0.1:3306)/blog"))
di.Provide("user.repository", NewMySQLUserRepository())
di.Provide("post.repository", NewMySQLPostRepository())
di.Provide("blog.service", NewBlogService())

if err := di.Resolve(); err != nil {
	log.Fatal(err)
}

Resolve first calls the Open function of every dependency, and then injects the dependencies based on the tags.

A dependency can be retrieved by name and used freely.

blogService := di.Get("blog.service").(BlogService)

profile, err := blogService.GetUserProfile(1)
if err != nil {
	log.Fatal(err)
}

fmt.Println(profile.User.Name)
for _, post := range profile.Posts {
	fmt.Println(post.Title, "-", post.Body)
}
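
The Provide, Resolve and Get steps can be approximated with a small reflection-based container. The sketch below is an assumption about how such a container might work; it is not the code of tinrab/kit. The names container, newContainer, clock and greeter are all hypothetical, and Dependency is the interface shape inferred from the Open and Close methods used throughout this article.

```go
package main

import (
	"fmt"
	"reflect"
)

// Dependency is the assumed shape of kit.Dependency.
type Dependency interface {
	Open() error
	Close()
}

// container is a hypothetical stand-in for the value returned by
// kit.NewDependencyInjection; it is not the library's implementation.
type container struct {
	deps map[string]Dependency
}

func newContainer() *container {
	return &container{deps: make(map[string]Dependency)}
}

// Provide registers a dependency under a name.
func (c *container) Provide(name string, dep Dependency) {
	c.deps[name] = dep
}

// Get returns the dependency registered under name, or nil.
func (c *container) Get(name string) Dependency {
	return c.deps[name]
}

// Resolve opens every dependency, then fills each field tagged with
// `inject:"name"` with the dependency registered under that name.
func (c *container) Resolve() error {
	for name, dep := range c.deps {
		if err := dep.Open(); err != nil {
			return fmt.Errorf("open %q: %w", name, err)
		}
	}
	for name, dep := range c.deps {
		v := reflect.ValueOf(dep)
		if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
			continue // nothing to inject into
		}
		s := v.Elem()
		for i := 0; i < s.NumField(); i++ {
			tag, ok := s.Type().Field(i).Tag.Lookup("inject")
			if !ok {
				continue
			}
			src, found := c.deps[tag]
			if !found {
				return fmt.Errorf("%q needs missing dependency %q", name, tag)
			}
			field := s.Field(i)
			if !field.CanSet() {
				return fmt.Errorf("%q: field %d is not settable", name, i)
			}
			srcVal := reflect.ValueOf(src)
			if !srcVal.Type().AssignableTo(field.Type()) {
				return fmt.Errorf("%q: %q has the wrong type", name, tag)
			}
			field.Set(srcVal)
		}
	}
	return nil
}

// clock and greeter are hypothetical dependencies used to exercise the sketch.
type clock struct{ opened bool }

func (c *clock) Open() error { c.opened = true; return nil }
func (c *clock) Close()      { c.opened = false }

type greeter struct {
	Clock *clock `inject:"clock"`
}

func (g *greeter) Open() error { return nil }
func (g *greeter) Close()      {}
```

After providing a clock and a greeter and calling Resolve, the greeter's Clock field points at the opened clock. Only exported fields can be set through reflection, which is one plausible reason the tagged fields in this article are exported.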

Testing

Dependency injection is especially helpful during testing.

Here, the user and post repositories are replaced with stubs in order to test the blog service.

Write a fake repository that implements the UserRepository interface.

package main

import (
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/tinrab/kit"
)

type userRepositoryStub struct {
	users map[uint64]*User
}

func (r *userRepositoryStub) Open() error {
	r.users = map[uint64]*User{
		1: &User{ID: 1, Name: "User1"},
		2: &User{ID: 2, Name: "User2"},
		3: &User{ID: 3, Name: "User3"},
	}
	return nil
}

func (r *userRepositoryStub) Close() {
}

func (r *userRepositoryStub) GetUserByID(id uint64) (*User, error) {
	if user, ok := r.users[id]; ok {
		return user, nil
	}
	return nil, errors.New("User not found")
}

Do the same for the PostRepository interface.

type postRepositoryStub struct {
	postsByUserID map[uint64][]Post
}

func (r *postRepositoryStub) Open() error {
	r.postsByUserID = map[uint64][]Post{
		1: []Post{
			Post{ID: 1, UserID: 1, Title: "A", Body: "A"},
			Post{ID: 2, UserID: 1, Title: "B", Body: "B"},
		},
	}
	return nil
}

func (r *postRepositoryStub) Close() {
}

func (r *postRepositoryStub) GetPostsByUser(userID uint64) ([]Post, error) {
	if posts, ok := r.postsByUserID[userID]; ok {
		return posts, nil
	}
	return []Post{}, nil
}

Here is what a unit test could look like.

// Same file as the stubs above, so the package clause and imports are not repeated.
func TestBlog(t *testing.T) {
	di := kit.NewDependencyInjection()

	// The stubs never use the database. sql.Open may only validate its
	// arguments, so providing it does not necessarily require a running server.
	di.Provide("database", NewMySQLDatabase("root:123456@tcp(127.0.0.1:3306)/blog"))
	di.Provide("user.repository", &userRepositoryStub{})
	di.Provide("post.repository", &postRepositoryStub{})
	di.Provide("blog.service", NewBlogService())

	if err := di.Resolve(); err != nil {
		t.Fatal(err)
	}

	blogService := di.Get("blog.service").(BlogService)
	profile, err := blogService.GetUserProfile(1)
	if err != nil {
		t.Fatal(err)
	}

	assert.Equal(t, "User1", profile.User.Name)
	// Stop early if the slice is too short, so the indexing below cannot panic.
	if !assert.Len(t, profile.Posts, 2) {
		t.FailNow()
	}
	assert.Equal(t, uint64(1), profile.Posts[0].UserID)
	assert.Equal(t, "A", profile.Posts[0].Title)
	assert.Equal(t, "A", profile.Posts[0].Body)
	assert.Equal(t, uint64(1), profile.Posts[1].UserID)
	assert.Equal(t, "B", profile.Posts[1].Title)
	assert.Equal(t, "B", profile.Posts[1].Body)
}
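
Stubs also make failure paths easy to reach. The following sketch, which assumes nothing about tinrab/kit, checks that an unknown user ID surfaces as an error. The userGetter interface, the errUserNotFound sentinel and the userName helper are hypothetical simplifications of the article's UserRepository and blog service, introduced only to keep the example self-contained.

```go
package main

import "errors"

// User mirrors the struct declared earlier in the article.
type User struct {
	ID   uint64
	Name string
}

// userGetter is a trimmed-down, hypothetical version of UserRepository
// without the Open and Close methods.
type userGetter interface {
	GetUserByID(id uint64) (*User, error)
}

// errUserNotFound is a sentinel error (an assumption; the article's stub
// builds a new error value on every call).
var errUserNotFound = errors.New("user not found")

// userRepositoryStub serves users from memory instead of a database.
type userRepositoryStub struct {
	users map[uint64]*User
}

func (r *userRepositoryStub) GetUserByID(id uint64) (*User, error) {
	if user, ok := r.users[id]; ok {
		return user, nil
	}
	return nil, errUserNotFound
}

// userName is a hypothetical helper standing in for the blog service:
// it forwards repository errors to the caller unchanged.
func userName(repo userGetter, id uint64) (string, error) {
	user, err := repo.GetUserByID(id)
	if err != nil {
		return "", err
	}
	return user.Name, nil
}
```

With a sentinel error, a test can compare the returned error against errUserNotFound rather than matching on the message text, which keeps the assertion stable if the wording changes.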

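Copyright notice: this content comes from the internet and the copyright belongs to the original author. Unless it cannot be confirmed, we always credit the author and the source. If there is any infringement, please let us know and we will remove the content immediately with our apologies. Thank you.

This article is part of the 腾讯云自媒体同步曝光计划 and is shared from the Golang语言社区 WeChat official account.
Originally published: 2018-08-26. In case of infringement, please contact cloudcommunity@tencent.com for removal.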
