gorm 源码阅读之 schema
发布于: 2021 年 04 月 01 日
基于 https://github.com/go-gorm/gorm v1.21.x
数据映射
我们来看看不用 gorm 直接使用 mysql 驱动连接查询一个 user 的代码
type User struct { Id int `json:"id"` Name string `json:"name"` Age int `json:"age"`}func queryUser(db *sql.DB){ fmt.Println("query times:",i) user := new(User) row := db.QueryRow("select * from users where id=?", 1) //row.scan中的字段必须是按照数据库存入字段的顺序,否则报错 if err := row.Scan(&user.Id,&user.Name,&user.Age); err != nil{ fmt.Printf("scan failed, err:%v",err) return } fmt.Println(*user) }}
复制代码
来看看 gorm 是怎么做的
func queryUser(db *gorm.DB){ var user User if err := db.First(&user).Error; nil == err { fmt.Printf("user:%+v\n", user) }}
复制代码
gorm 帮我们解决了数据字段和 struct 结构的数据映射,这也是一个 orm 的关键所在。
gorm 其实使用的是反射,在 gorm 源码里面,Schema 是数据映射的这块的核心,
Schema 实际上就是保存了目标对象,也就是 user 的数据结构
type DB struct { *Config Error error RowsAffected int64 Statement *Statement clone int}// db.Statement.Schema就是Schema对象了type Statement struct { //... Schema *schema.Schema //...}
复制代码
初始化 Schema
初始化 db 的时候,初始化 db.Statement 并没有初始化 Schema
func Open(dialector Dialector, opts ...Option) (db *DB, err error) { //...可以看到初始化Statement的时候,Statement.Schema并没有初始化 db.Statement = &Statement{ DB: db, ConnPool: db.ConnPool, Context: context.Background(), Clauses: map[string]clause.Clause{}, } return}
//从入口开始找Schema的初始化堆栈{ var user User if err := db.First(&user).Error; nil != err { fmt.Printf("err:%v\n", err) } fmt.Printf("user:%+v\n", user)}
//First调用tx.callbacks.Query().Execute(tx)func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { //... //这里的Statement.Dest,也就是数据最终要保存到的user tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) return}
//processor的Execute里初始化Schemafunc (p *processor) Execute(db *DB) { //... //有两种方式,可以告知gorm目标对象的结构 //第一种:db.Model(&user).Update("name", "hello"),直接传一个Model //第二种:db.First(&user),到First里,这个&user就是Dest if stmt.Model == nil { stmt.Model = stmt.Dest } else if stmt.Dest == nil { stmt.Dest = stmt.Model }
if stmt.Model != nil { //stmt.Parse利用反射开始读取Model的结构 if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" { db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err)) } else { db.AddError(err) } } } //... //执行callback,真正的查询的地方 for _, f := range p.fns { f(db) } //...}
//这里的value就是&userfunc (stmt *Statement) Parse(value interface{}) (err error) { if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.Table = tables[1] return }
stmt.Table = stmt.Schema.Table } return err}
复制代码
要理解 Parse 里面的代码要先将下面关于反射的测试代码理解了,
也可以参考https://xie.infoq.cn/article/09c3cfe6918c938d266af2182
var blog Blog structValue := reflect.ValueOf(blog) fmt.Printf("structValue value:%v\n", structValue) // structValue value:{0 0} fmt.Printf("structValue type:%v\n", structValue.Type()) // structValue type:main.Blog fmt.Printf("structValue kind:%v\n", structValue.Type().Kind()) // structValue kind:struct fmt.Printf("structValue CanSet:%t\n", structValue.CanSet()) // structValue CanSet:false
structPtrValue := reflect.ValueOf(&blog) fmt.Printf("structPtrValue value:%v\n", structPtrValue) // structPtrValue value:&{0 0} fmt.Printf("structPtrValue type:%v\n", structPtrValue.Type()) // structPtrValue type:*main.Blog fmt.Printf("structPtrValue type Elem:%v\n", structPtrValue.Type().Elem()) // structPtrValue type Elem:main.Blog fmt.Printf("structPtrValue kind:%v\n", structPtrValue.Type().Kind()) // structPtrValue kind:ptr fmt.Printf("structPtrValue CanSet:%t\n", structPtrValue.CanSet()) // structPtrValue CanSet:false fmt.Printf("structPtrValue Elem CanSet:%t\n", structPtrValue.Elem().CanSet()) // structPtrValue Elem CanSet:true
复制代码
现在来看 Parse(),Parse 就是利用反射将 user 里的每一个属性读到 Field 里去,比如属性名称,属性索引等
//这里的dest就是&userfunc Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { if dest == nil { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) }
modelType := reflect.ValueOf(dest).Type()//*main.User //modelType.Kind()是等于reflect.Ptr的 for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem()//main.User }
//传入的dest必须是struct if modelType.Kind() != reflect.Struct { if modelType.PkgPath() == "" { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) }
//如果缓存有保存这个schema,直接返回 if v, ok := cacheStore.Load(modelType); ok { s := v.(*Schema) <-s.initialized return s, s.err }
//新建一个main.User modelValue := reflect.New(modelType) //modelType.Name() == User tableName := namer.TableName(modelType.Name()) //如果有自定义的tablename,就用自定义的那个tablename if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } if en, ok := namer.(embeddedNamer); ok { tableName = en.Table }
//初始化Schema对象 schema := &Schema{ Name: modelType.Name(), ModelType: modelType, Table: tableName, FieldsByName: map[string]*Field{}, FieldsByDBName: map[string]*Field{}, Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, initialized: make(chan struct{}), }
defer func() { if schema.err != nil { logger.Default.Error(context.Background(), schema.err.Error()) cacheStore.Delete(modelType) } }()
//得到user的所有的fieldStruct for i := 0; i < modelType.NumField(); i++ { //ast.IsExported判断Field是不是对外开放的,也就是属性名以大写开头 if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { //将fieldStruct解析到Field if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil { schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...) } else { schema.Fields = append(schema.Fields, field) } } } //... return schema, schema.err}
复制代码
如何给 user 赋值
现在 user 的 Schema 已经解析完了,那在哪里将数据库中的数据,设置到 user 里面去呢?
//callbacks/query.go里面定义了真正的查询的地方func Query(db *gorm.DB) { if db.Error == nil { BuildQuerySQL(db)
if !db.DryRun && db.Error == nil { rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err != nil { db.AddError(err) return } defer rows.Close()
gorm.Scan(rows, db, false) } }}
//Scan将数据库的数据丢到user里func Scan(rows *sql.Rows, db *DB, initialized bool) { columns, _ := rows.Columns() values := make([]interface{}, len(columns)) db.RowsAffected = 0
//First()里面tx.Statement.Dest = dest,就是这里的dest switch dest := db.Statement.Dest.(type) { case map[string]interface{}, *map[string]interface{}: //... default: //由于user是struct会走到这里 Schema := db.Statement.Schema switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: //... case reflect.Struct, reflect.Ptr: if db.Statement.ReflectValue.Type() != Schema.ModelType { Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) }
if initialized || rows.Next() { //构建values准备接收数据库的值 for idx, column := range columns { if field := Schema.LookUpField(column); field != nil && field.Readable { values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() } else if names := strings.Split(column, "__"); len(names) > 1 { //如果找不到column对应的field,判断column是否包含"__",这是什么情况? if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() continue } } values[idx] = &sql.RawBytes{} } else { values[idx] = &sql.RawBytes{} } }
db.RowsAffected++ //读取数据库返回的值 db.AddError(rows.Scan(values...))
for idx, column := range columns { if field := Schema.LookUpField(column); field != nil && field.Readable { //在callbacks.go, processor.Execute里stmt.ReflectValue = reflect.ValueOf(stmt.Dest) //将读取的值赋值给user的field field.Set(db.Statement.ReflectValue, values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) value := reflect.ValueOf(values[idx]).Elem()
if relValue.Kind() == reflect.Ptr && relValue.IsNil() { if value.IsNil() { continue } relValue.Set(reflect.New(relValue.Type().Elem())) }
field.Set(relValue, values[idx]) } } } } } } }
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { db.AddError(ErrRecordNotFound) }}
复制代码
划线
评论
复制
发布于: 2021 年 04 月 01 日阅读数: 69
版权声明: 本文为 InfoQ 作者【werben】的原创文章。
原文链接:【http://xie.infoq.cn/article/6692a870ae0514f8d657c6990】。文章转载请联系作者。
werben
关注
还未添加个人签名 2018.01.08 加入
还未添加个人简介











评论