-
Notifications
You must be signed in to change notification settings - Fork 39
Open
Description
运行效果:
推理失败是因为中文标点的关系。
请问我需要修改什么
使用的代码如下:
#!/usr/bin/env python3
--
# -*- coding: utf-8 -*-
import json
import torch
from model import BraLM, Vocab # 确保 model.py 在当前目录或在 PYTHONPATH 中
# ---------- 1. 启动阶段:只做一次 ----------
print("[INFO] 正在加载词表 …")
with open("vocab_wiki_4k.json") as f:
node_dict = json.load(f)
vocab = Vocab.from_node_dict(node_dict)
print("[INFO] 正在加载词频 …")
with open("word_frequency.json", "r") as f:
freq_dict = json.load(f)
# 构造 zero_freq_edges
zero_freq_edges = {}
for s in freq_dict:
zero_freq_edges[s] = []
for t in freq_dict[s]:
if freq_dict[s][t] == 0:
zero_freq_edges[s].append(t)
print("[INFO] 正在构建模型 …")
model = BraLM(hidden_size=32, zero_freq_edges=zero_freq_edges, vocab=vocab)
model.prepare_network(vocab)
print("[INFO] 正在加载权重 …")
state_dict = torch.load("/openbayes/input/input0/BriLLM0.5/model_zh.bin", weights_only=True)
model.load_state_dict(state_dict)
model.to_device("cuda:0")
print("[INFO] 模型已就绪,输入 exit 或 quit 可退出。\n")
# ---------- 2. 循环交互阶段 ----------
while True:
try:
head = input(">>> ")
except (EOFError, KeyboardInterrupt):
# Ctrl-D / Ctrl-C 也能优雅退出
print("\n[INFO] 用户中断,程序结束。")
break
if head.strip().lower() in {"exit", "quit"}:
print("[INFO] 用户请求退出,程序结束。")
break
if not head:
continue # 空输入直接跳过
# 推理
max_token = 32 - len(head)
try:
start = [vocab(head[i] + "->" + head[i + 1]) for i in range(len(head) - 1)]
ret = model.decode(start, vocab, max_token)
decode_tuple_list = [vocab.decode(p) for p in ret]
decode_sentence = decode_tuple_list[0][0] + "".join([p[-1] for p in decode_tuple_list])
print(decode_sentence)
except Exception as e:
print(f"[ERROR] 推理失败:{e}")
Metadata
Metadata
Assignees
Labels
No labels