写点什么

gorm 源码阅读之 schema

用户头像
werben
关注
发布于: 2021 年 04 月 01 日

基于 https://github.com/go-gorm/gorm v1.21.x

数据映射

我们来看看不用 gorm 直接使用 mysql 驱动连接查询一个 user 的代码


type User struct {  Id       int    `json:"id"`  Name     string `json:"name"`  Age      int    `json:"age"`}func queryUser(db *sql.DB){        fmt.Println("query times:",i)        user := new(User)        row := db.QueryRow("select * from users where id=?", 1)        //row.scan中的字段必须是按照数据库存入字段的顺序,否则报错        if err := row.Scan(&user.Id,&user.Name,&user.Age); err != nil{            fmt.Printf("scan failed, err:%v",err)            return        }        fmt.Println(*user)    }}
复制代码


来看看 gorm 是怎么做的


func queryUser(db *gorm.DB){  var user User    if err := db.First(&user).Error; nil == err {    fmt.Printf("user:%+v\n", user)    }}
复制代码


gorm 帮我们解决了数据字段和 struct 结构的数据映射,这也是一个 orm 的关键所在。


gorm 其实使用的是反射,在 gorm 源码里面,Schema 是数据映射的这块的核心,


Schema 实际上就是保存了目标对象,也就是 user 的数据结构


type DB struct {  *Config  Error        error  RowsAffected int64  Statement    *Statement  clone        int}// db.Statement.Schema就是Schema对象了type Statement struct {  //...  Schema               *schema.Schema    //...}
复制代码

初始化 Schema

初始化 db 的时候,初始化 db.Statement 并没有初始化 Schema


func Open(dialector Dialector, opts ...Option) (db *DB, err error) {  //...可以看到初始化Statement的时候,Statement.Schema并没有初始化  db.Statement = &Statement{    DB:       db,    ConnPool: db.ConnPool,    Context:  context.Background(),    Clauses:  map[string]clause.Clause{},  }  return}
//从入口开始找Schema的初始化堆栈{ var user User if err := db.First(&user).Error; nil != err { fmt.Printf("err:%v\n", err) } fmt.Printf("user:%+v\n", user)}
//First调用tx.callbacks.Query().Execute(tx)func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { //... //这里的Statement.Dest,也就是数据最终要保存到的user tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) return}
//processor的Execute里初始化Schemafunc (p *processor) Execute(db *DB) { //... //有两种方式,可以告知gorm目标对象的结构 //第一种:db.Model(&user).Update("name", "hello"),直接传一个Model //第二种:db.First(&user),到First里,这个&user就是Dest if stmt.Model == nil { stmt.Model = stmt.Dest } else if stmt.Dest == nil { stmt.Dest = stmt.Model }
if stmt.Model != nil { //stmt.Parse利用反射开始读取Model的结构 if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" { db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err)) } else { db.AddError(err) } } } //... //执行callback,真正的查询的地方 for _, f := range p.fns { f(db) } //...}
//这里的value就是&userfunc (stmt *Statement) Parse(value interface{}) (err error) { if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.Table = tables[1] return }
stmt.Table = stmt.Schema.Table } return err}
复制代码


要理解 Parse 里面的代码要先将下面关于反射的测试代码理解了,

也可以参考https://xie.infoq.cn/article/09c3cfe6918c938d266af2182


  var blog Blog  structValue := reflect.ValueOf(blog)  fmt.Printf("structValue value:%v\n", structValue)              // structValue value:{0  0}  fmt.Printf("structValue type:%v\n", structValue.Type())        // structValue type:main.Blog  fmt.Printf("structValue kind:%v\n", structValue.Type().Kind()) // structValue kind:struct  fmt.Printf("structValue CanSet:%t\n", structValue.CanSet())    // structValue CanSet:false
structPtrValue := reflect.ValueOf(&blog) fmt.Printf("structPtrValue value:%v\n", structPtrValue) // structPtrValue value:&{0 0} fmt.Printf("structPtrValue type:%v\n", structPtrValue.Type()) // structPtrValue type:*main.Blog fmt.Printf("structPtrValue type Elem:%v\n", structPtrValue.Type().Elem()) // structPtrValue type Elem:main.Blog fmt.Printf("structPtrValue kind:%v\n", structPtrValue.Type().Kind()) // structPtrValue kind:ptr fmt.Printf("structPtrValue CanSet:%t\n", structPtrValue.CanSet()) // structPtrValue CanSet:false fmt.Printf("structPtrValue Elem CanSet:%t\n", structPtrValue.Elem().CanSet()) // structPtrValue Elem CanSet:true
复制代码


现在来看 Parse(),Parse 就是利用反射将 user 里的每一个属性读到 Field 里去,比如属性名称,属性索引等


//这里的dest就是&userfunc Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {  if dest == nil {    return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)  }
modelType := reflect.ValueOf(dest).Type()//*main.User //modelType.Kind()是等于reflect.Ptr的 for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem()//main.User }
//传入的dest必须是struct if modelType.Kind() != reflect.Struct { if modelType.PkgPath() == "" { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) }
//如果缓存有保存这个schema,直接返回 if v, ok := cacheStore.Load(modelType); ok { s := v.(*Schema) <-s.initialized return s, s.err }
//新建一个main.User modelValue := reflect.New(modelType) //modelType.Name() == User tableName := namer.TableName(modelType.Name()) //如果有自定义的tablename,就用自定义的那个tablename if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } if en, ok := namer.(embeddedNamer); ok { tableName = en.Table }
//初始化Schema对象 schema := &Schema{ Name: modelType.Name(), ModelType: modelType, Table: tableName, FieldsByName: map[string]*Field{}, FieldsByDBName: map[string]*Field{}, Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, initialized: make(chan struct{}), }
defer func() { if schema.err != nil { logger.Default.Error(context.Background(), schema.err.Error()) cacheStore.Delete(modelType) } }()
//得到user的所有的fieldStruct for i := 0; i < modelType.NumField(); i++ { //ast.IsExported判断Field是不是对外开放的,也就是属性名以大写开头 if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { //将fieldStruct解析到Field if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil { schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...) } else { schema.Fields = append(schema.Fields, field) } } } //... return schema, schema.err}
复制代码

如何给 user 赋值

现在 user 的 Schema 已经解析完了,那在哪里将数据库中的数据,设置到 user 里面去呢?


//callbacks/query.go里面定义了真正的查询的地方func Query(db *gorm.DB) {  if db.Error == nil {    BuildQuerySQL(db)
if !db.DryRun && db.Error == nil { rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err != nil { db.AddError(err) return } defer rows.Close()
gorm.Scan(rows, db, false) } }}
//Scan将数据库的数据丢到user里func Scan(rows *sql.Rows, db *DB, initialized bool) { columns, _ := rows.Columns() values := make([]interface{}, len(columns)) db.RowsAffected = 0
//First()里面tx.Statement.Dest = dest,就是这里的dest switch dest := db.Statement.Dest.(type) { case map[string]interface{}, *map[string]interface{}: //... default: //由于user是struct会走到这里 Schema := db.Statement.Schema switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: //... case reflect.Struct, reflect.Ptr: if db.Statement.ReflectValue.Type() != Schema.ModelType { Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) }
if initialized || rows.Next() { //构建values准备接收数据库的值 for idx, column := range columns { if field := Schema.LookUpField(column); field != nil && field.Readable { values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() } else if names := strings.Split(column, "__"); len(names) > 1 { //如果找不到column对应的field,判断column是否包含"__",这是什么情况? if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() continue } } values[idx] = &sql.RawBytes{} } else { values[idx] = &sql.RawBytes{} } }
db.RowsAffected++ //读取数据库返回的值 db.AddError(rows.Scan(values...))
for idx, column := range columns { if field := Schema.LookUpField(column); field != nil && field.Readable { //在callbacks.go, processor.Execute里stmt.ReflectValue = reflect.ValueOf(stmt.Dest) //将读取的值赋值给user的field field.Set(db.Statement.ReflectValue, values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) value := reflect.ValueOf(values[idx]).Elem()
if relValue.Kind() == reflect.Ptr && relValue.IsNil() { if value.IsNil() { continue } relValue.Set(reflect.New(relValue.Type().Elem())) }
field.Set(relValue, values[idx]) } } } } } } }
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { db.AddError(ErrRecordNotFound) }}
复制代码


发布于: 2021 年 04 月 01 日阅读数: 69
用户头像

werben

关注

还未添加个人签名 2018.01.08 加入

还未添加个人简介

评论

发布
暂无评论
gorm源码阅读之schema