Batch-asking questions from a CSV with the GLM model (glm_ask_batch_csv)
Link: http://note.youdao.com/noteshare?id=6cc558b45e7def3df017a0e54cca03fb&sub=B1A08392E5104E27BA51369C5CBE7F97
from transformers import AutoTokenizer, AutoModel
from top.starp.util import json_util
from top.starp.util import list_util
from top.starp.util import time_util
from top.starp.util import file_util
# Local path to the ChatGLM-6B checkpoint (Hugging Face ID: "THUDM/chatglm-6b")
model_path = '/j05025/home/work/chatglm-6b'
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Load the weights in fp16 on the GPU and switch to inference mode
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
model = model.eval()
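Before running the whole batch, it can help to confirm the checkpoint loads and answers at all. A minimal single-turn sanity check, using the same model.chat call that the batch function below relies on (the prompt string here is just an illustrative placeholder):

# Quick sanity check: one question, empty history
test_response, test_history = model.chat(tokenizer, "你好", history=[])
print(test_response)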
now_time_str = time_util.get_now_time_str()

def questions_list_ask(questions_list, startIdx=130000, ask_type="legal"):
    """Ask the model every question in questions_list and save each
    question/answer pair as its own JSON file under out_dir."""
    idx = startIdx
    out_dir = fr"/j05025/datasets/chatglm_test_out/chatglm_test_out_{now_time_str}_start_{startIdx}"
    for one in questions_list:
        content = one
        ask = content
        # Single-turn chat: no conversation history is carried over between questions
        response, history = model.chat(tokenizer, ask, history=[])
        log_data = {
            "ask": ask,
            "response": response,
        }
        out_path = f"{out_dir}/{ask_type}_start_{startIdx}_{idx}.json"
        print("out_path", out_path)
        json_util.json_to_file(log_data, out_path)
        idx += 1
file_name="/j05025/datasets/starp/LegalQA-master/LegalQA-manual-train.csv"
# questions=file_util.read_csv_col(file_name,"question: subject")
questions=file_util.read_csv_col(file_name,"question: body")
questions=set(questions)
questions_list=list(questions)
questions_list_ask(questions_list,startIdx=0,ask_type="LegalQA")
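Since every answer ends up in its own JSON file, the results can afterwards be collected back into a single CSV. A rough sketch using only the standard library, assuming json_util.json_to_file writes plain JSON; the out_dir value and the output filename below are placeholders and must be replaced with the timestamped directory actually printed by the run above:

import glob
import json
import csv

# Placeholder: use the directory printed as out_path during the batch run
out_dir = "/j05025/datasets/chatglm_test_out/chatglm_test_out_<now_time_str>_start_0"
rows = []
for path in sorted(glob.glob(f"{out_dir}/*.json")):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    rows.append({"ask": data["ask"], "response": data["response"]})

with open("chatglm_answers.csv", "w", encoding="utf-8", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["ask", "response"])
    writer.writeheader()
    writer.writerows(rows)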
import csv

def read_csv_col(filename='data.csv', col_name='name', encoding='utf-8'):
    """Return the values of the column named col_name as a list."""
    # filename: path to your CSV file
    col_data_list = []
    with open(filename, 'r', encoding=encoding) as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            value = row[col_name]
            col_data_list.append(value)
    return col_data_list
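For reference, calling this helper directly gives the same list that file_util.read_csv_col returns above; the path and column name below are the ones already used in the script:

questions = read_csv_col(
    filename="/j05025/datasets/starp/LegalQA-master/LegalQA-manual-train.csv",
    col_name="question: body",
)
print(len(questions), questions[:3])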