package tdmap import ( "context" "database/sql" "errors" "fmt" "reflect" "strings" ) func NewMapper() *Mapper { return &Mapper{} } // TableNamer 定义了一个获取表名的方法。 type TableNamer interface { TableName() string } type Mapper struct { structMateMap SyncMap[string, *StructMate] // 结构体 元信息缓存 key: struct 的唯一类型名称 value: 元信息 } func (b *Mapper) scanStruct(data ...any) error { for _, datum := range data { if reflect.TypeOf(datum).Kind() != reflect.Ptr { //return fmt.Errorf("需要指针类型:%v", reflect.TypeOf(datum)) return fmt.Errorf("need a pointer type: %v", reflect.TypeOf(datum)) } if mate, err := scan(datum); err != nil { return err } else { b.structMateMap.Store(mate.UniqueTypeName, mate) } } return nil } // 提取 struct 内的信息 // @param data 包含 struct 类型的数据 // @return 返回一个 map[表名][]*TableRowMateria, 如果超级表名为空,则表示普通表 func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria, error) { result := make(map[string][]*TableRowMateria) for _, item := range data { tf, vf := getReflectTypeAndValue(item) uniqueTypeName := getUniqueTypeName(tf) // 获取表名 var tableName string { if vf.CanAddr() { if a, ok := vf.Addr().Interface().(TableNamer); ok { tableName = a.TableName() } } if tableName == "" { if a, ok := item.(TableNamer); ok { tableName = a.TableName() } } if tableName == "" { return nil, fmt.Errorf("not import TableName() string func for struct type: %s", uniqueTypeName) } } mate, ok := b.structMateMap.Load(uniqueTypeName) if !ok { return nil, fmt.Errorf("not found struct type: %s", uniqueTypeName) } materia := &TableRowMateria{ SuperTableName: mate.SuperTableName, TableName: tableName, TagColumns: make([]string, 0, len(mate.TaggedFieldNames)), TagValues: make([]any, 0, len(mate.TaggedFieldNames)), Columns: make([]string, 0, len(mate.DBAnnotatedNames)), Values: make([]any, 0, len(mate.DBAnnotatedNames)), } for _, name := range mate.DBAnnotatedNames { // db tag 名称 -- 数据库列名 dbColumn := mate.Filed2DBNameCache[name] field := vf.FieldByIndex(mate.Field2IndexCache[name]) for field.Kind() == reflect.Ptr { if field.IsNil() { break } field = field.Elem() } // 字段值 dbValue := field.Interface() materia.Columns = append(materia.Columns, dbColumn) materia.Values = append(materia.Values, dbValue) } for _, name := range mate.TaggedFieldNames { // db tag 名称 -- 数据库列名 tagColumn := mate.Filed2DBNameCache[name] field := vf.FieldByIndex(mate.Field2IndexCache[name]) for field.Kind() == reflect.Ptr { if field.IsNil() { break } field = field.Elem() } // 字段值 tagValue := field.Interface() materia.TagColumns = append(materia.TagColumns, tagColumn) materia.TagValues = append(materia.TagValues, tagValue) } // 添加到结果 key := materia.TableName result[key] = append(result[key], materia) } return result, nil } func (b *Mapper) ToInsertSQL(data ...any) (string, error) { // 扫描 struct 类型数据 if err := b.scanStruct(data...); err != nil { return "", fmt.Errorf("failed to scan struct: %w", err) } // 提取 struct 内的信息 tableMap, err := b.extractStructData(data...) if err != nil { return "", fmt.Errorf("failed to extract struct data: %w", err) } if len(tableMap) == 0 { return "", fmt.Errorf("no data to insert") } // 构建 insert 语句 /* INSERT INTO 表名1 USING 超级表名1 (tag列名) TAGS(tag值) (列名) VALUES (),() 表名2 USING 超级表名2 (tag列名)TAGS(tag值) (列名) VALUES (),() */ var buf strings.Builder buf.WriteString("INSERT INTO \n") for _, materials := range tableMap { if len(materials) == 0 { continue } var ( rowSql string err error ) if materials[0].SuperTableName == "" { // 表名1 (列名) VALUES (),() rowSql, err = buildInsertStatementForNormalTable(materials...) } else { // 表名1 USING 超级表名1 (tag列名) TAGS(tag值) (列名) VALUES (),() rowSql, err = buildInsertStatementForSuperTable(materials...) } if err != nil { return "", fmt.Errorf("failed to build insert statement: %w", err) } buf.WriteString(rowSql) buf.WriteString(" \n") } return buf.String(), nil } // ScanRows 扫描多行数据 // A Superior Option: Use ScanRowsWithContext instead. func (b *Mapper) ScanRows(target any, rows *sql.Rows) error { return b.ScanRowsWithContext(context.Background(), target, rows) } // ScanRowsWithContext 扫描多行数据 // @param ctx 上下文 // @param target 结构体的指针 或 slice 指针 结构体需要使用 db 标签注解 // @param rows 数据库返回的行数据 // @return 返回错误 func (b *Mapper) ScanRowsWithContext(ctx context.Context, target any, rows *sql.Rows) error { // 确保rows不为nil if rows == nil { return errors.New("rows cannot be nil") } // 获取target的反射类型和值 vf := reflect.ValueOf(target) // 检查target是否为指针类型 if vf.Kind() != reflect.Ptr { return errors.New("target must be a pointer to a struct or a slice of structs") } if vf.IsNil() { return errors.New("target is nil, please pass a pointer to a struct or a slice of structs") } vf = vf.Elem() // 解引用指针以获取目标值 switch vf.Kind() { case reflect.Struct: // 提取 struct 内的信息 if err := b.scanStruct(target); err != nil { return err } // target是指向单个struct的指针,扫描一行数据 if !rows.Next() { return sql.ErrNoRows } return b.scanRow(vf, rows) case reflect.Slice: // target是slice的指针 sliceElementType := vf.Type().Elem() if sliceElementType.Kind() != reflect.Ptr || sliceElementType.Elem().Kind() != reflect.Struct { return errors.New("target must be a pointer to a slice of structs") } // 提取 slice 中 struct 内的信息 if err := b.scanStruct(reflect.New(sliceElementType.Elem()).Interface()); err != nil { // 这里在抛出 不是指针类型 return err } for rows.Next() { select { case <-ctx.Done(): vf.SetLen(0) // 清空 slice return ctx.Err() default: } // 创建slice元素的新实例 newStruct := reflect.New(sliceElementType.Elem()) if err := b.scanRow(newStruct, rows); err != nil { return err } // 将新实例添加到slice中 vf.Set(reflect.Append(vf, newStruct)) } if vf.Len() == 0 { // 如果slice为空,返回 sql.ErrNoRows return sql.ErrNoRows } return nil default: return errors.New("target must be a pointer to a struct or a slice of structs") } } // scanRow 用于扫描单行数据到结构体 func (b *Mapper) scanRow(target reflect.Value, rows *sql.Rows) error { target = reflect.Indirect(target) uniqueTypeName := getUniqueTypeName(target.Type()) mate, ok := b.structMateMap.Load(uniqueTypeName) if !ok { return fmt.Errorf("not found struct type mate: %s", uniqueTypeName) } // 假设我们有一个函数来获取列的值 columns, err := rows.Columns() if err != nil { return err } // 准备目标结构体的地址 dest := make([]interface{}, len(columns)) for i, colName := range columns { // 拿到缓存的字段索引 if idx, ok := mate.DBName2IndexCache[colName]; ok { dest[i] = target.FieldByIndex(idx).Addr().Interface() } else { return fmt.Errorf("no corresponding field found for column %s", colName) } } // 扫描数据 return rows.Scan(dest...) }