Skip to content

[Bug] ddddocr不支持自定义模型:预处理标准化参数、输出节点名、后处理方式均与官方模型不同 #4

@130347665

Description

@130347665

问题描述

使用 dddd-trainer 训练的自定义 ONNX 模型时,Go 版本识别结果与 Python ddddocr 不一致。经过深入对比 Python 源码,发现自定义模型与官方模型在多个关键环节存在差异,当前 Go 实现仅支持官方模型。

复现步骤

  1. 使用 dddd-trainer 训练自定义模型
  2. 同一张验证码图片分别用 Python ddddocr 和 Go go-ocr 进行识别
  3. 对比识别结果

实际行为

  • Python 识别结果:vusc(正确)
  • Go 识别结果:lwv(错误)

差异分析

对比 Python ddddocr 源码,发现官方模型与自定义模型存在 4 个关键差异

项目 官方模型 自定义模型 (dddd-trainer)
标准化参数 (x - 0.5) / 0.5 (x - 0.456) / 0.224
输出节点名 "387" "output"
输出数据类型 float32 int64
后处理方式 argmax + CTC 解码 直接 CTC 解码

问题一:图像标准化参数

当前代码:

inputData[i] = float32(pix) / 255.0  // ❌ 缺少标准化

Python ddddocr:

# 官方模型 (use_import_onnx=False)
image = (image - 0.5) / 0.5

# 自定义模型 (use_import_onnx=True, channel=1)
image = (image - 0.456) / 0.224

问题二:image.Gray 的 Stride 处理

当前代码:

for i, pix := range grayImg.Pix {  // ❌ 没有考虑 Stride
    inputData[i] = ...
}

修正:

for y := 0; y < targetH; y++ {
    for x := 0; x < targetW; x++ {
        pix := grayImg.Pix[y*grayImg.Stride+x]  // ✅ 使用 Stride
        inputData[y*targetW+x] = ...
    }
}

问题三:输出节点名和数据类型不同

当前代码(官方模型):

outputValue := outputValues["387"]
outputData, err := ort.GetTensorData[float32](outputValue)

自定义模型需要:

outputValue := outputValues["output"]
outputData, err := ort.GetTensorData[int64](outputValue)

问题四:后处理方式不同

  • 官方模型:输出是 float32 概率矩阵,需要 argmax 找最大值索引再 CTC 解码
  • 自定义模型:输出直接是 int64 字符索引,直接 CTC 解码即可

建议修复方案

1. 在 Config 和 Engine 中添加标志

type Config struct {
    ModelPath      string
    DictPath       string
    DetModelPath   string
    UseCustomModel bool   // 新增:true = 自定义模型
}

type Engine struct {
    ocrSession     *ort.Session
    detSession     *ort.Session
    dict           []string
    useCustomModel bool   // 新增
}

2. 修改 NewEngine

func NewEngine(cfg Config) (*Engine, error) {
    // ...
    engine := &Engine{
        useCustomModel: cfg.UseCustomModel,
    }
    // ...
}

3. 修改 preprocessOCR

func (e *Engine) preprocessOCR(img image.Image) ([]float32, []int64, error) {
    targetH := 64
    dstImg := imageutil.Resize(img, 0, targetH)
    targetW := dstImg.Bounds().Dx()

    grayImg := imageutil.Grayscale(dstImg)
    inputData := make([]float32, 1*1*targetH*targetW)

    for y := 0; y < targetH; y++ {
        for x := 0; x < targetW; x++ {
            pix := grayImg.Pix[y*grayImg.Stride+x]
            normalized := float32(pix) / 255.0

            if e.useCustomModel {
                inputData[y*targetW+x] = (normalized - 0.456) / 0.224
            } else {
                inputData[y*targetW+x] = (normalized - 0.5) / 0.5
            }
        }
    }

    shape := []int64{1, 1, int64(targetH), int64(targetW)}
    return inputData, shape, nil
}

4. 修改 Classification

func (e *Engine) Classification(img image.Image) (string, error) {
    if e.ocrSession == nil {
        return "", fmt.Errorf("OCR 引擎未初始化")
    }

    inputData, inputShape, err := e.preprocessOCR(img)
    if err != nil {
        return "", err
    }

    inputTensor, err := ort.NewTensor(inputShape, inputData)
    if err != nil {
        return "", err
    }
    defer inputTensor.Destroy()

    inputValues := map[string]*ort.Value{
        "input1": inputTensor,
    }

    outputValues, err := e.ocrSession.Run(inputValues)
    if err != nil {
        return "", fmt.Errorf("OCR 推理失败: %w", err)
    }

    if e.useCustomModel {
        // 自定义模型:输出节点 "output",类型 int64
        outputValue := outputValues["output"]
        defer outputValue.Destroy()

        outputData, err := ort.GetTensorData[int64](outputValue)
        if err != nil {
            return "", fmt.Errorf("获取 OCR 输出数据失败: %w", err)
        }

        var sb strings.Builder
        lastIdx := int64(-1)
        for _, idx := range outputData {
            if idx != 0 && idx != lastIdx {
                if int(idx) < len(e.dict) {
                    sb.WriteString(e.dict[idx])
                }
            }
            lastIdx = idx
        }
        return sb.String(), nil

    } else {
        // 官方模型:输出节点 "387",类型 float32
        outputValue := outputValues["387"]
        defer outputValue.Destroy()

        outputData, err := ort.GetTensorData[float32](outputValue)
        if err != nil {
            return "", fmt.Errorf("获取 OCR 输出数据失败: %w", err)
        }

        seqLen := int(math.Ceil(float64(inputShape[3]) / 8.0))
        return e.postprocessOCR(outputData, seqLen), nil
    }
}

5. 使用方式

// 官方模型
engine, _ := ddddocr.NewEngine(ddddocr.Config{
    ModelPath:      "common_old.onnx",
    DictPath:       "common_old.json",
    UseCustomModel: false,
})

// 自定义模型
engine, _ := ddddocr.NewEngine(ddddocr.Config{
    ModelPath:      "my_model.onnx",
    DictPath:       "my_charsets.json",
    UseCustomModel: true,
})

环境信息

  • Go 版本:1.21+
  • 操作系统:Windows
  • Python ddddocr 版本:最新版
  • 模型来源:dddd-trainer 自训练模型

参考

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions