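# Added the prediction code; accuracy can still be improved. Still unresolved:
# because of how BERT encodes the text, predictions cannot always be mapped
# back to the original sentence.
# The maximum length is 512; the data has already been filtered.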
import tensorflow as tf
import numpy as np
from bert import modeling
from bert import tokenization
from bert import optimization
import os
# Command-line flags for this script, registered through tf.flags and read
# back through FLAGS.
flags = tf.flags
FLAGS = flags.FLAGS

# Training settings.
flags.DEFINE_integer(
    name='train_batch_size',
    default=6,
    help='define the train batch size',
)
flags.DEFINE_integer(
    name='num_train_epochs',
    default=3,
    help='define the num train epochs',
)
flags.DEFINE_float(
    name='warmup_proportion',
    default=0.1,
    help='define the warmup proportion',
)
flags.DEFINE_float(
    name='learning_rate',
    default=5e-5,
    help='the initial learning rate for adam',
)

# Model settings.
# NOTE: the help text below says "weather" (meant "whether"); it is kept
# byte-for-byte because it is a runtime string.
flags.DEFINE_bool(
    name='is_training',
    default=True,
    help='define weather fine-tune the bert model',
)
flags.DEFINE_integer(
    name='max_sentence_len',
    default=512,
    help='define the max len of sentence',
)

# Task switches; both default to True.
flags.DEFINE_bool(
    name='task_train',
    default=True,
    help='define the train task',
)
flags.DEFINE_bool(
    name='task_predict',
    default=True,
    help='define the predict task',
)
def get_start_end_index(text, subtext):
    """Locate the first occurrence of ``subtext`` inside ``text``.

    Both arguments must be sliceable sequences that compare equal by value
    when sliced (for example two strings, or two lists).

    Args:
        text: The sequence to search in.
        subtext: The sequence to search for.

    Returns:
        A ``(start, end)`` tuple of indices into ``text`` where ``end`` is
        inclusive, or ``(-1, -1)`` when ``subtext`` does not occur in
        ``text`` or ``subtext`` is empty.
    """
    sub_len = len(subtext)
    # An empty needle has no meaningful span; without this guard the result
    # would be (0, -1), i.e. an end index that precedes the start index.
    if sub_len == 0:
        return (-1, -1)
    # Only start positions that leave room for the whole needle can match;
    # the range is empty when the needle is longer than the text.
    for start in range(len(text) - sub_len + 1):
        if text[start:start + sub_len] == subtext:
            return (start, start + sub_len - 1)
    return (-1, -1)
# Load the training examples: every line of the file is stripped and
# evaluated as a Python expression, and the results are collected in order.
train_data = []
with open('data/train_data.txt', encoding='UTF-8') as fp:
    # NOTE(review): eval() runs whatever expression a line contains. This is
    # only acceptable while the data file is fully trusted; consider
    # ast.literal_eval if every line is a plain literal -- TODO confirm.
    strLines = [eval(raw_line.strip()) for raw_line in fp]
train_data.extend(strLines)

# Load the test examples in exactly the same way.
test_data = []
with open('data/test_data.txt', encoding='UTF-8') as fp:
    # NOTE(review): same eval() caveat as for the training file above.
    strLines = [eval(raw_line.strip()) for raw_line in fp]
test_data.extend(strLines)
# Earlier absolute Windows locations of the same three files, kept for
# reference only:
# config_path = r'D:\NLP_SOUNDAI\learnTensor\package9\bert\chinese_L-12_H-768_A-12\bert_config.json'
# checkpoint_path = r'D:\NLP_SOUNDAI\learnTensor\package9\bert\chinese_L-12_H-768_A-12\bert_model.ckpt'
# dict_path = r'D:\NLP_SOUNDAI\learnTensor\package9\bert\chinese_L-12_H-768_A-12\vocab.txt'

# Locations of the BERT config, checkpoint and vocabulary files, relative to
# the current working directory.
config_path = './bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = './bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = './bert/chinese_L-12_H-768_A-12/vocab.txt'
# BERT configuration object built from the JSON config file.
bert_config = modeling.BertConfig.from_json_file(config_path)
# Tokenizer built from the vocabulary file, with do_lower_case enabled.
tokenizer = tokenization.FullTokenizer(vocab_file=dict_path,do_lower_case=True)
def input_str_concat(inputList):
a