perf(scan): 使用缓存减少对象扫描
抽取函数,优化变量命名 支持了传入数组类型
This commit is contained in:
178
mapping.go
178
mapping.go
@@ -19,35 +19,141 @@ type TableNamer interface {
|
||||
}
|
||||
|
||||
type Mapper struct {
|
||||
structMateMap SyncMap[string, *StructMate] // 结构体 元信息缓存 key: struct 的唯一类型名称 value: 元信息
|
||||
structMateMap SyncMap[string, *StructMeta] // 结构体 元信息缓存 key: struct 的唯一类型名称 value: 元信息
|
||||
}
|
||||
|
||||
func (b *Mapper) scanStruct(data ...any) error {
|
||||
for _, datum := range data {
|
||||
if reflect.TypeOf(datum).Kind() != reflect.Ptr {
|
||||
//return fmt.Errorf("需要指针类型:%v", reflect.TypeOf(datum))
|
||||
return fmt.Errorf("need a pointer type: %v", reflect.TypeOf(datum))
|
||||
return fmt.Errorf("need a pointer type: %v", reflect.TypeOf(data))
|
||||
}
|
||||
|
||||
if mate, err := scan(datum); err != nil {
|
||||
// 分析结构体并缓存结果
|
||||
if _, err := b.analyzeStructWithCache(datum); err != nil {
|
||||
return err
|
||||
} else {
|
||||
b.structMateMap.Store(mate.UniqueTypeName, mate)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// analyzeStructWithCache 带缓存的结构体分析
|
||||
func (b *Mapper) analyzeStructWithCache(data any) (*StructMeta, error) {
|
||||
// 获取结构体的类型
|
||||
t, v := extractReflectInfo(data)
|
||||
|
||||
// 确保是结构体类型
|
||||
if t.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("input data is not a struct")
|
||||
}
|
||||
|
||||
// 获取类型的唯一标识符
|
||||
uniqueTypeName := buildTypeKey(t)
|
||||
|
||||
// 先检查缓存
|
||||
if cached, ok := b.structMateMap.Load(uniqueTypeName); ok {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
// 初始化结果结构体
|
||||
sr := StructMeta{
|
||||
UniqueTypeName: uniqueTypeName,
|
||||
DBName2IndexCache: make(map[string][]int, t.NumField()),
|
||||
Field2IndexCache: make(map[string][]int, t.NumField()),
|
||||
Field2DBNameCache: make(map[string]string, t.NumField()),
|
||||
DBAnnotatedNames: make([]string, 0, t.NumField()),
|
||||
TaggedFieldNames: make([]string, 0, t.NumField()),
|
||||
SuperTableName: "",
|
||||
}
|
||||
|
||||
// 遍历结构体的字段
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fieldValue := v.Field(i)
|
||||
|
||||
if field.Anonymous {
|
||||
// 处理匿名结构体
|
||||
if field.Type.Kind() == reflect.Ptr {
|
||||
if fieldValue.IsNil() {
|
||||
// 如果指针是 nil,创建一个该类型的零值实例
|
||||
zeroValue := reflect.Zero(field.Type.Elem())
|
||||
fieldValue = zeroValue
|
||||
} else {
|
||||
fieldValue = fieldValue.Elem()
|
||||
}
|
||||
}
|
||||
|
||||
if !fieldValue.CanInterface() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 递归分析匿名结构体,使用相同的缓存
|
||||
subResult, err := b.analyzeStructWithCache(fieldValue.Interface())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 合并分析结果
|
||||
for k, v := range subResult.DBName2IndexCache {
|
||||
sr.DBName2IndexCache[k] = append(sr.DBName2IndexCache[k], i)
|
||||
sr.DBName2IndexCache[k] = append(sr.DBName2IndexCache[k], v...)
|
||||
}
|
||||
|
||||
for k, v := range subResult.Field2IndexCache {
|
||||
sr.Field2IndexCache[k] = append(sr.Field2IndexCache[k], i)
|
||||
sr.Field2IndexCache[k] = append(sr.Field2IndexCache[k], v...)
|
||||
}
|
||||
for k, v := range subResult.Field2DBNameCache {
|
||||
sr.Field2DBNameCache[k] = v
|
||||
}
|
||||
|
||||
sr.DBAnnotatedNames = append(sr.DBAnnotatedNames, subResult.DBAnnotatedNames...)
|
||||
sr.TaggedFieldNames = append(sr.TaggedFieldNames, subResult.TaggedFieldNames...)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理普通字段
|
||||
columnName := field.Tag.Get("db")
|
||||
if columnName == "-" || columnName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(sr.Field2IndexCache[field.Name]) > 0 {
|
||||
return nil, fmt.Errorf("duplicate field [%s %s `db:%s`]", field.Name, field.Type.Name(), columnName)
|
||||
}
|
||||
|
||||
sr.Field2IndexCache[field.Name] = append([]int{}, i)
|
||||
sr.DBName2IndexCache[columnName] = append(sr.DBName2IndexCache[columnName], i)
|
||||
sr.Field2DBNameCache[field.Name] = columnName
|
||||
|
||||
// 检查字段是否有taos注解
|
||||
if field.Tag.Get("taos") == "tag" {
|
||||
sr.TaggedFieldNames = append(sr.TaggedFieldNames, field.Name)
|
||||
} else {
|
||||
sr.DBAnnotatedNames = append(sr.DBAnnotatedNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if sTableName, ok := tryGetSuperTableName(data); ok {
|
||||
sr.SuperTableName = sTableName
|
||||
}
|
||||
|
||||
// 将结果存入缓存
|
||||
b.structMateMap.Store(uniqueTypeName, &sr)
|
||||
|
||||
return &sr, nil
|
||||
}
|
||||
|
||||
// 提取 struct 内的信息
|
||||
// @param data 包含 struct 类型的数据
|
||||
// @return 返回一个 map[表名][]*TableRowMateria, 如果超级表名为空,则表示普通表
|
||||
func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria, error) {
|
||||
result := make(map[string][]*TableRowMateria)
|
||||
// @return 返回一个 map[表名][]*TableRowMaterial, 如果超级表名为空,则表示普通表
|
||||
func (b *Mapper) extractTableMaterial(data ...any) (map[string][]*TableRowMaterial, error) {
|
||||
result := make(map[string][]*TableRowMaterial)
|
||||
|
||||
for _, item := range data {
|
||||
tf, vf := getReflectTypeAndValue(item)
|
||||
uniqueTypeName := getUniqueTypeName(tf)
|
||||
tf, vf := extractReflectInfo(item)
|
||||
uniqueTypeName := buildTypeKey(tf)
|
||||
|
||||
// 获取表名
|
||||
var tableName string
|
||||
@@ -74,7 +180,7 @@ func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria,
|
||||
return nil, fmt.Errorf("not found struct type: %s", uniqueTypeName)
|
||||
}
|
||||
|
||||
materia := &TableRowMateria{
|
||||
trMaterial := &TableRowMaterial{
|
||||
SuperTableName: mate.SuperTableName,
|
||||
TableName: tableName,
|
||||
TagColumns: make([]string, 0, len(mate.TaggedFieldNames)),
|
||||
@@ -85,7 +191,7 @@ func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria,
|
||||
|
||||
for _, name := range mate.DBAnnotatedNames {
|
||||
// db tag 名称 -- 数据库列名
|
||||
dbColumn := mate.Filed2DBNameCache[name]
|
||||
dbColumn := mate.Field2DBNameCache[name]
|
||||
|
||||
field := vf.FieldByIndex(mate.Field2IndexCache[name])
|
||||
for field.Kind() == reflect.Ptr {
|
||||
@@ -97,13 +203,13 @@ func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria,
|
||||
// 字段值
|
||||
dbValue := field.Interface()
|
||||
|
||||
materia.Columns = append(materia.Columns, dbColumn)
|
||||
materia.Values = append(materia.Values, dbValue)
|
||||
trMaterial.Columns = append(trMaterial.Columns, dbColumn)
|
||||
trMaterial.Values = append(trMaterial.Values, dbValue)
|
||||
}
|
||||
|
||||
for _, name := range mate.TaggedFieldNames {
|
||||
// db tag 名称 -- 数据库列名
|
||||
tagColumn := mate.Filed2DBNameCache[name]
|
||||
tagColumn := mate.Field2DBNameCache[name]
|
||||
field := vf.FieldByIndex(mate.Field2IndexCache[name])
|
||||
for field.Kind() == reflect.Ptr {
|
||||
if field.IsNil() {
|
||||
@@ -114,13 +220,13 @@ func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria,
|
||||
// 字段值
|
||||
tagValue := field.Interface()
|
||||
|
||||
materia.TagColumns = append(materia.TagColumns, tagColumn)
|
||||
materia.TagValues = append(materia.TagValues, tagValue)
|
||||
trMaterial.TagColumns = append(trMaterial.TagColumns, tagColumn)
|
||||
trMaterial.TagValues = append(trMaterial.TagValues, tagValue)
|
||||
}
|
||||
|
||||
// 添加到结果
|
||||
key := materia.TableName
|
||||
result[key] = append(result[key], materia)
|
||||
key := trMaterial.TableName
|
||||
result[key] = append(result[key], trMaterial)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -135,17 +241,17 @@ func (b *Mapper) normalizeData(data ...any) []any {
|
||||
var normalizedData []any
|
||||
|
||||
for _, item := range data {
|
||||
if item == nil {
|
||||
if isNil(item) {
|
||||
// 跳过nil值,避免后续反射调用panic
|
||||
continue
|
||||
}
|
||||
itemType := reflect.TypeOf(item)
|
||||
if itemType.Kind() == reflect.Slice {
|
||||
if itemType.Kind() == reflect.Slice || itemType.Kind() == reflect.Array {
|
||||
// 处理切片类型参数,展平成单个元素
|
||||
sliceValue := reflect.ValueOf(item)
|
||||
for i := 0; i < sliceValue.Len(); i++ {
|
||||
elem := sliceValue.Index(i).Interface()
|
||||
if elem == nil {
|
||||
if isNil(elem) {
|
||||
// 跳过nil元素
|
||||
continue
|
||||
}
|
||||
@@ -159,18 +265,10 @@ func (b *Mapper) normalizeData(data ...any) []any {
|
||||
|
||||
// 转换所有元素为指针类型,便于后续一致性处理
|
||||
for i, item := range normalizedData {
|
||||
if item == nil {
|
||||
if isNil(item) {
|
||||
continue
|
||||
}
|
||||
itemType := reflect.TypeOf(item)
|
||||
if itemType.Kind() != reflect.Ptr {
|
||||
// 创建对应类型的指针
|
||||
itemPtr := reflect.New(itemType)
|
||||
// 设置指针指向的值为元素本身
|
||||
itemPtr.Elem().Set(reflect.ValueOf(item))
|
||||
// 替换切片中的元素为指针类型
|
||||
normalizedData[i] = itemPtr.Interface()
|
||||
}
|
||||
normalizedData[i] = ensurePtr(item)
|
||||
}
|
||||
return normalizedData
|
||||
}
|
||||
@@ -185,11 +283,11 @@ func (b *Mapper) ToInsertSQL(data ...any) (string, error) {
|
||||
|
||||
// 扫描 struct 类型数据
|
||||
if err := b.scanStruct(normalizedData...); err != nil {
|
||||
return "", fmt.Errorf("failed to scan struct: %w", err)
|
||||
return "", fmt.Errorf("failed to analyzeStruct struct: %w", err)
|
||||
}
|
||||
|
||||
// 提取 struct 内的信息
|
||||
tableMap, err := b.extractStructData(normalizedData...)
|
||||
tableMap, err := b.extractTableMaterial(normalizedData...)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to extract struct data: %w", err)
|
||||
}
|
||||
@@ -246,7 +344,7 @@ func (b *Mapper) ScanRows(target any, rows *sql.Rows) error {
|
||||
// @return 返回错误
|
||||
func (b *Mapper) ScanRowsWithContext(ctx context.Context, target any, rows *sql.Rows) error {
|
||||
// 确保rows不为nil
|
||||
if rows == nil {
|
||||
if isNil(rows) {
|
||||
return errors.New("rows cannot be nil")
|
||||
}
|
||||
|
||||
@@ -275,8 +373,8 @@ func (b *Mapper) ScanRowsWithContext(ctx context.Context, target any, rows *sql.
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
return b.scanRow(vf, rows)
|
||||
case reflect.Slice:
|
||||
// target是slice的指针
|
||||
case reflect.Slice, reflect.Array:
|
||||
// target是slice或array,扫描多行数据
|
||||
sliceElementType := vf.Type().Elem()
|
||||
if sliceElementType.Kind() != reflect.Ptr || sliceElementType.Elem().Kind() != reflect.Struct {
|
||||
return errors.New("target must be a pointer to a slice of structs")
|
||||
@@ -314,7 +412,7 @@ func (b *Mapper) ScanRowsWithContext(ctx context.Context, target any, rows *sql.
|
||||
func (b *Mapper) scanRow(target reflect.Value, rows *sql.Rows) error {
|
||||
target = reflect.Indirect(target)
|
||||
|
||||
uniqueTypeName := getUniqueTypeName(target.Type())
|
||||
uniqueTypeName := buildTypeKey(target.Type())
|
||||
mate, ok := b.structMateMap.Load(uniqueTypeName)
|
||||
if !ok {
|
||||
return fmt.Errorf("not found struct type mate: %s", uniqueTypeName)
|
||||
@@ -387,5 +485,5 @@ func (b *Mapper) ScanRowsToMapWithContext(ctx context.Context, rows *sql.Rows) (
|
||||
}
|
||||
result = append(result, rowMap)
|
||||
}
|
||||
return result, err
|
||||
return result, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user