使用 train.conll 作为训练数据集。train.conll 的部分内容如下（每行一个字及其标签）：
浙 B-prov
江 E-prov
杭 B-city
州 I-city
市 E-city
江 B-district
干 I-district
区 E-district
九 B-town
堡 I-town
镇 E-town
三 B-community
村 I-community
村 E-community
一 B-poi
区 E-poi
浙 B-prov
江
import os
import codecs
from collections import defaultdict

import sklearn_crfsuite
from sklearn_crfsuite import metrics
读取数据集
def read_data(file_path):
    """Read a CoNLL-style file into a list of tagged sentences.

    Every non-blank line must hold exactly two whitespace-separated fields:
    a token followed by its tag. Blank lines separate sentences.

    Args:
        file_path: Path of the UTF-8 encoded file to read.

    Returns:
        A list of sentences, each a list of ``(token, tag)`` tuples. Runs of
        consecutive blank lines never produce empty sentences.

    Raises:
        ValueError: If a non-blank line does not split into two fields.
    """
    data = []
    sent = []
    with codecs.open(file_path, 'r', encoding='utf-8') as f:
        # Stream the file instead of loading every line up front.
        for line_no, raw_line in enumerate(f, 1):
            line = raw_line.strip()
            if not line:
                # Only close a sentence that actually has tokens, so repeated
                # or leading blank lines do not yield empty sentences.
                if sent:
                    data.append(sent)
                    sent = []
                continue
            fields = line.split()
            if len(fields) != 2:
                raise ValueError(
                    '{}:{}: expected "token tag", got {!r}'.format(
                        file_path, line_no, line))
            sent.append((fields[0], fields[1]))
    # The last sentence may not be followed by a blank line.
    if sent:
        data.append(sent)
    return data
特征提取函数
def word2features(sent, i):
    """Build the CRF feature dict for the token at position ``i`` of ``sent``.

    Only observable information (the token itself and its neighbouring
    tokens) is used. Gold tags of neighbouring tokens are deliberately not
    turned into features: they are unavailable when tagging unlabelled text,
    and using them would leak the answers into evaluation.

    Args:
        sent: Sequence of items whose first element is the token string
            (for example ``(token, tag)`` tuples).
        i: Index of the token to describe.

    Returns:
        A dict mapping feature names to values.
    """
    word = sent[i][0]
    last_index = len(sent) - 1
    return {
        'bias': 1.0,
        'word': word,
        'word_len': len(word),
        'is_digit': word.isdigit(),
        'is_alpha': word.isalpha(),
        'is_upper': word.isupper(),
        'is_lower': word.islower(),
        'is_title': word.istitle(),
        # An empty string marks the sentence boundary.
        'prev_word': '' if i == 0 else sent[i - 1][0],
        'next_word': '' if i == last_index else sent[i + 1][0],
    }
def sent2features(sent):
    """Return one feature dict per token of ``sent``, in sentence order."""
    feature_rows = []
    for position in range(len(sent)):
        feature_rows.append(word2features(sent, position))
    return feature_rows
def sent2labels(sent):
    """Return the tags of a sentence given as ``(token, tag)`` pairs."""
    tags = []
    for _token, tag in sent:
        tags.append(tag)
    return tags
def sent2tokens(sent):
    """Return the tokens of a sentence given as ``(token, tag)`` pairs."""
    tokens = []
    for token, _tag in sent:
        tokens.append(token)
    return tokens
训练模型
def train_model(train_data, c1=0.1, c2=0.1, max_iterations=100):
    """Fit a linear-chain CRF tagger on labelled sentences.

    Args:
        train_data: List of sentences, each a list of ``(token, tag)`` tuples.
        c1: Value passed to ``sklearn_crfsuite.CRF`` as ``c1``.
        c2: Value passed to ``sklearn_crfsuite.CRF`` as ``c2``.
        max_iterations: Value passed to ``sklearn_crfsuite.CRF`` as
            ``max_iterations``.

    Returns:
        The fitted ``sklearn_crfsuite.CRF`` instance.
    """
    X_train = [sent2features(sent) for sent in train_data]
    y_train = [sent2labels(sent) for sent in train_data]
    crf = sklearn_crfsuite.CRF(
        algorithm='lbfgs',
        c1=c1,
        c2=c2,
        max_iterations=max_iterations,
        all_possible_transitions=True,
    )
    crf.fit(X_train, y_train)
    return crf
预测结果
def predict(crf, test_data):
    """Tag every sentence of ``test_data`` with ``crf``.

    Args:
        crf: Object exposing ``predict(list_of_feature_sequences)``.
        test_data: List of sentences accepted by ``sent2features``.

    Returns:
        Whatever ``crf.predict`` returns for the extracted features.
    """
    feature_sequences = []
    for sentence in test_data:
        feature_sequences.append(sent2features(sentence))
    return crf.predict(feature_sequences)
将预测结果写入文件
def write_result(file_path, test_data, y_pred):
    """Write predicted tags to a UTF-8 file, one token per line.

    Each output line is ``<sentence number>\\x01<token>\\x01<tag>``, where the
    sentence number starts at 1 and the separator is the U+0001 character.

    Args:
        file_path: Destination path; an existing file is overwritten.
        test_data: List of sentences; the first element of every item in a
            sentence is the token.
        y_pred: Predicted tags, indexed as ``y_pred[sentence][token]``.

    Raises:
        IndexError: If ``y_pred`` has fewer sentences or tags than
            ``test_data``.
    """
    with codecs.open(file_path, 'w', encoding='utf-8') as f:
        for sent_index, sent in enumerate(test_data):
            # Index rather than zip so a too-short prediction fails loudly
            # instead of silently truncating the output.
            tags = y_pred[sent_index]
            for token_index, item in enumerate(sent):
                f.write('{}\u0001{}\u0001{}\n'.format(
                    sent_index + 1, item[0], tags[token_index]))
评估模型
def evaluate(test_data, y_pred):
    """Print a per-tag classification report for the predictions.

    Args:
        test_data: List of labelled sentences, each a list of
            ``(token, tag)`` tuples; their tags are the reference labels.
        y_pred: Predicted tag sequences, one per sentence of ``test_data``.
    """
    gold_labels = []
    for sentence in test_data:
        gold_labels.append(sent2labels(sentence))
    report = metrics.flat_classification_report(gold_labels, y_pred)
    print(report)
# Script entry point: train on the labelled corpus, tag the test file,
# write the predictions, then print an evaluation report.
if __name__ == '__main__':
    train_file = 'train.conll'
    test_file = '1.txt'
    result_file = '对对对队_addr_parsing_runid.txt'

    train_data = read_data(train_file)
    test_data = read_data(test_file)

    crf = train_model(train_data)
    y_pred = predict(crf, test_data)

    write_result(result_file, test_data, y_pred)
    # NOTE(review): evaluate() reads reference tags from test_data, so the
    # test file must be labelled in the same format as the training file.
    evaluate(test_data, y_pred)
原文地址: https://www.cveoy.top/t/topic/fzd3 著作权归作者所有。请勿转载和采集!