gorm 源码阅读之 schema
发布于: 2021 年 04 月 01 日
基于 https://github.com/go-gorm/gorm v1.21.x
数据映射
我们来看看不用 gorm 直接使用 mysql 驱动连接查询一个 user 的代码
type User struct {
Id int `json:"id"`
Name string `json:"name"`
Age int `json:"age"`
}
func queryUser(db *sql.DB){
fmt.Println("query times:",i)
user := new(User)
row := db.QueryRow("select * from users where id=?", 1)
//row.scan中的字段必须是按照数据库存入字段的顺序,否则报错
if err := row.Scan(&user.Id,&user.Name,&user.Age); err != nil{
fmt.Printf("scan failed, err:%v",err)
return
}
fmt.Println(*user)
}
}
复制代码
来看看 gorm 是怎么做的
func queryUser(db *gorm.DB){
var user User
if err := db.First(&user).Error; nil == err {
fmt.Printf("user:%+v\n", user)
}
}
复制代码
gorm 帮我们解决了数据字段和 struct 结构的数据映射,这也是一个 orm 的关键所在。
gorm 其实使用的是反射,在 gorm 源码里面,Schema 是数据映射的这块的核心,
Schema 实际上就是保存了目标对象,也就是 user 的数据结构
type DB struct {
*Config
Error error
RowsAffected int64
Statement *Statement
clone int
}
// db.Statement.Schema就是Schema对象了
type Statement struct {
//...
Schema *schema.Schema
//...
}
复制代码
初始化 Schema
初始化 db 的时候,初始化 db.Statement 并没有初始化 Schema
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
//...可以看到初始化Statement的时候,Statement.Schema并没有初始化
db.Statement = &Statement{
DB: db,
ConnPool: db.ConnPool,
Context: context.Background(),
Clauses: map[string]clause.Clause{},
}
return
}
//从入口开始找Schema的初始化堆栈
{
var user User
if err := db.First(&user).Error; nil != err {
fmt.Printf("err:%v\n", err)
}
fmt.Printf("user:%+v\n", user)
}
//First调用tx.callbacks.Query().Execute(tx)
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
//...
//这里的Statement.Dest,也就是数据最终要保存到的user
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
return
}
//processor的Execute里初始化Schema
func (p *processor) Execute(db *DB) {
//...
//有两种方式,可以告知gorm目标对象的结构
//第一种:db.Model(&user).Update("name", "hello"),直接传一个Model
//第二种:db.First(&user),到First里,这个&user就是Dest
if stmt.Model == nil {
stmt.Model = stmt.Dest
} else if stmt.Dest == nil {
stmt.Dest = stmt.Model
}
if stmt.Model != nil {
//stmt.Parse利用反射开始读取Model的结构
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
} else {
db.AddError(err)
}
}
}
//...
//执行callback,真正的查询的地方
for _, f := range p.fns {
f(db)
}
//...
}
//这里的value就是&user
func (stmt *Statement) Parse(value interface{}) (err error) {
if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
stmt.Table = tables[1]
return
}
stmt.Table = stmt.Schema.Table
}
return err
}
复制代码
要理解 Parse 里面的代码要先将下面关于反射的测试代码理解了,
也可以参考https://xie.infoq.cn/article/09c3cfe6918c938d266af2182
var blog Blog
structValue := reflect.ValueOf(blog)
fmt.Printf("structValue value:%v\n", structValue) // structValue value:{0 0}
fmt.Printf("structValue type:%v\n", structValue.Type()) // structValue type:main.Blog
fmt.Printf("structValue kind:%v\n", structValue.Type().Kind()) // structValue kind:struct
fmt.Printf("structValue CanSet:%t\n", structValue.CanSet()) // structValue CanSet:false
structPtrValue := reflect.ValueOf(&blog)
fmt.Printf("structPtrValue value:%v\n", structPtrValue) // structPtrValue value:&{0 0}
fmt.Printf("structPtrValue type:%v\n", structPtrValue.Type()) // structPtrValue type:*main.Blog
fmt.Printf("structPtrValue type Elem:%v\n", structPtrValue.Type().Elem()) // structPtrValue type Elem:main.Blog
fmt.Printf("structPtrValue kind:%v\n", structPtrValue.Type().Kind()) // structPtrValue kind:ptr
fmt.Printf("structPtrValue CanSet:%t\n", structPtrValue.CanSet()) // structPtrValue CanSet:false
fmt.Printf("structPtrValue Elem CanSet:%t\n", structPtrValue.Elem().CanSet()) // structPtrValue Elem CanSet:true
复制代码
现在来看 Parse(),Parse 就是利用反射将 user 里的每一个属性读到 Field 里去,比如属性名称,属性索引等
//这里的dest就是&user
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
modelType := reflect.ValueOf(dest).Type()//*main.User
//modelType.Kind()是等于reflect.Ptr的
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()//main.User
}
//传入的dest必须是struct
if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
}
//如果缓存有保存这个schema,直接返回
if v, ok := cacheStore.Load(modelType); ok {
s := v.(*Schema)
<-s.initialized
return s, s.err
}
//新建一个main.User
modelValue := reflect.New(modelType)
//modelType.Name() == User
tableName := namer.TableName(modelType.Name())
//如果有自定义的tablename,就用自定义的那个tablename
if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName()
}
if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table
}
//初始化Schema对象
schema := &Schema{
Name: modelType.Name(),
ModelType: modelType,
Table: tableName,
FieldsByName: map[string]*Field{},
FieldsByDBName: map[string]*Field{},
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
initialized: make(chan struct{}),
}
defer func() {
if schema.err != nil {
logger.Default.Error(context.Background(), schema.err.Error())
cacheStore.Delete(modelType)
}
}()
//得到user的所有的fieldStruct
for i := 0; i < modelType.NumField(); i++ {
//ast.IsExported判断Field是不是对外开放的,也就是属性名以大写开头
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
//将fieldStruct解析到Field
if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
} else {
schema.Fields = append(schema.Fields, field)
}
}
}
//...
return schema, schema.err
}
复制代码
如何给 user 赋值
现在 user 的 Schema 已经解析完了,那在哪里将数据库中的数据,设置到 user 里面去呢?
//callbacks/query.go里面定义了真正的查询的地方
func Query(db *gorm.DB) {
if db.Error == nil {
BuildQuerySQL(db)
if !db.DryRun && db.Error == nil {
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
db.AddError(err)
return
}
defer rows.Close()
gorm.Scan(rows, db, false)
}
}
}
//Scan将数据库的数据丢到user里
func Scan(rows *sql.Rows, db *DB, initialized bool) {
columns, _ := rows.Columns()
values := make([]interface{}, len(columns))
db.RowsAffected = 0
//First()里面tx.Statement.Dest = dest,就是这里的dest
switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}:
//...
default:
//由于user是struct会走到这里
Schema := db.Statement.Schema
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
//...
case reflect.Struct, reflect.Ptr:
if db.Statement.ReflectValue.Type() != Schema.ModelType {
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
if initialized || rows.Next() {
//构建values准备接收数据库的值
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 {
//如果找不到column对应的field,判断column是否包含"__",这是什么情况?
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
db.RowsAffected++
//读取数据库返回的值
db.AddError(rows.Scan(values...))
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
//在callbacks.go, processor.Execute里stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
//将读取的值赋值给user的field
field.Set(db.Statement.ReflectValue, values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
value := reflect.ValueOf(values[idx]).Elem()
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
field.Set(relValue, values[idx])
}
}
}
}
}
}
}
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound {
db.AddError(ErrRecordNotFound)
}
}
复制代码
划线
评论
复制
发布于: 2021 年 04 月 01 日阅读数: 69
版权声明: 本文为 InfoQ 作者【werben】的原创文章。
原文链接:【http://xie.infoq.cn/article/6692a870ae0514f8d657c6990】。文章转载请联系作者。
werben
关注
还未添加个人签名 2018.01.08 加入
还未添加个人简介
评论