问题描述
使用 dddd-trainer 训练的自定义 ONNX 模型时,Go 版本识别结果与 Python ddddocr 不一致。经过深入对比 Python 源码,发现自定义模型与官方模型在多个关键环节存在差异,当前 Go 实现仅支持官方模型。
复现步骤
- 使用 dddd-trainer 训练自定义模型
- 同一张验证码图片分别用 Python ddddocr 和 Go go-ocr 进行识别
- 对比识别结果
实际行为
- Python 识别结果:
vusc(正确)
- Go 识别结果:
lwv(错误)
差异分析
对比 Python ddddocr 源码,发现官方模型与自定义模型存在 4 个关键差异:
| 项目 |
官方模型 |
自定义模型 (dddd-trainer) |
| 标准化参数 |
(x - 0.5) / 0.5 |
(x - 0.456) / 0.224 |
| 输出节点名 |
"387" |
"output" |
| 输出数据类型 |
float32 |
int64 |
| 后处理方式 |
argmax + CTC 解码 |
直接 CTC 解码 |
问题一:图像标准化参数
当前代码:
inputData[i] = float32(pix) / 255.0 // ❌ 缺少标准化
Python ddddocr:
# 官方模型 (use_import_onnx=False)
image = (image - 0.5) / 0.5
# 自定义模型 (use_import_onnx=True, channel=1)
image = (image - 0.456) / 0.224
问题二:image.Gray 的 Stride 处理
当前代码:
for i, pix := range grayImg.Pix { // ❌ 没有考虑 Stride
inputData[i] = ...
}
修正:
for y := 0; y < targetH; y++ {
for x := 0; x < targetW; x++ {
pix := grayImg.Pix[y*grayImg.Stride+x] // ✅ 使用 Stride
inputData[y*targetW+x] = ...
}
}
问题三:输出节点名和数据类型不同
当前代码(官方模型):
outputValue := outputValues["387"]
outputData, err := ort.GetTensorData[float32](outputValue)
自定义模型需要:
outputValue := outputValues["output"]
outputData, err := ort.GetTensorData[int64](outputValue)
问题四:后处理方式不同
- 官方模型:输出是
float32 概率矩阵,需要 argmax 找最大值索引再 CTC 解码
- 自定义模型:输出直接是
int64 字符索引,直接 CTC 解码即可
建议修复方案
1. 在 Config 和 Engine 中添加标志
type Config struct {
ModelPath string
DictPath string
DetModelPath string
UseCustomModel bool // 新增:true = 自定义模型
}
type Engine struct {
ocrSession *ort.Session
detSession *ort.Session
dict []string
useCustomModel bool // 新增
}
2. 修改 NewEngine
func NewEngine(cfg Config) (*Engine, error) {
// ...
engine := &Engine{
useCustomModel: cfg.UseCustomModel,
}
// ...
}
3. 修改 preprocessOCR
func (e *Engine) preprocessOCR(img image.Image) ([]float32, []int64, error) {
targetH := 64
dstImg := imageutil.Resize(img, 0, targetH)
targetW := dstImg.Bounds().Dx()
grayImg := imageutil.Grayscale(dstImg)
inputData := make([]float32, 1*1*targetH*targetW)
for y := 0; y < targetH; y++ {
for x := 0; x < targetW; x++ {
pix := grayImg.Pix[y*grayImg.Stride+x]
normalized := float32(pix) / 255.0
if e.useCustomModel {
inputData[y*targetW+x] = (normalized - 0.456) / 0.224
} else {
inputData[y*targetW+x] = (normalized - 0.5) / 0.5
}
}
}
shape := []int64{1, 1, int64(targetH), int64(targetW)}
return inputData, shape, nil
}
4. 修改 Classification
func (e *Engine) Classification(img image.Image) (string, error) {
if e.ocrSession == nil {
return "", fmt.Errorf("OCR 引擎未初始化")
}
inputData, inputShape, err := e.preprocessOCR(img)
if err != nil {
return "", err
}
inputTensor, err := ort.NewTensor(inputShape, inputData)
if err != nil {
return "", err
}
defer inputTensor.Destroy()
inputValues := map[string]*ort.Value{
"input1": inputTensor,
}
outputValues, err := e.ocrSession.Run(inputValues)
if err != nil {
return "", fmt.Errorf("OCR 推理失败: %w", err)
}
if e.useCustomModel {
// 自定义模型:输出节点 "output",类型 int64
outputValue := outputValues["output"]
defer outputValue.Destroy()
outputData, err := ort.GetTensorData[int64](outputValue)
if err != nil {
return "", fmt.Errorf("获取 OCR 输出数据失败: %w", err)
}
var sb strings.Builder
lastIdx := int64(-1)
for _, idx := range outputData {
if idx != 0 && idx != lastIdx {
if int(idx) < len(e.dict) {
sb.WriteString(e.dict[idx])
}
}
lastIdx = idx
}
return sb.String(), nil
} else {
// 官方模型:输出节点 "387",类型 float32
outputValue := outputValues["387"]
defer outputValue.Destroy()
outputData, err := ort.GetTensorData[float32](outputValue)
if err != nil {
return "", fmt.Errorf("获取 OCR 输出数据失败: %w", err)
}
seqLen := int(math.Ceil(float64(inputShape[3]) / 8.0))
return e.postprocessOCR(outputData, seqLen), nil
}
}
5. 使用方式
// 官方模型
engine, _ := ddddocr.NewEngine(ddddocr.Config{
ModelPath: "common_old.onnx",
DictPath: "common_old.json",
UseCustomModel: false,
})
// 自定义模型
engine, _ := ddddocr.NewEngine(ddddocr.Config{
ModelPath: "my_model.onnx",
DictPath: "my_charsets.json",
UseCustomModel: true,
})
环境信息
- Go 版本:1.21+
- 操作系统:Windows
- Python ddddocr 版本:最新版
- 模型来源:dddd-trainer 自训练模型
参考
问题描述
使用 dddd-trainer 训练的自定义 ONNX 模型时,Go 版本识别结果与 Python ddddocr 不一致。经过深入对比 Python 源码,发现自定义模型与官方模型在多个关键环节存在差异,当前 Go 实现仅支持官方模型。
复现步骤
实际行为
vusc(正确)lwv(错误)差异分析
对比 Python ddddocr 源码,发现官方模型与自定义模型存在 4 个关键差异:
(x - 0.5) / 0.5(x - 0.456) / 0.224"387""output"float32int64问题一:图像标准化参数
当前代码:
Python ddddocr:
问题二:image.Gray 的 Stride 处理
当前代码:
修正:
问题三:输出节点名和数据类型不同
当前代码(官方模型):
自定义模型需要:
问题四:后处理方式不同
float32概率矩阵,需要 argmax 找最大值索引再 CTC 解码int64字符索引,直接 CTC 解码即可建议修复方案
1. 在 Config 和 Engine 中添加标志
2. 修改 NewEngine
3. 修改 preprocessOCR
4. 修改 Classification
5. 使用方式
环境信息
参考
classification方法中use_import_onnx的分支处理