# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description: Convert the SIGHAN-2015 and CGED corpora into JSON files for Chinese spelling correction.
"""
import os
import sys
from codecs import open
from xml.dom import minidom

from lxml import etree
from sklearn.model_selection import train_test_split

sys.path.append('../..')
from pycorrector.utils.io_utils import save_json
from pycorrector import traditional2simplified

pwd_path = os.path.abspath(os.path.dirname(__file__))

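# The SIGHAN-2015 training data ships as SGML files whose names end in
# 'ing.sgml' (matched by read_data() below). Judging from the XPath queries in
# proc_item(), each essay block looks roughly like this sketch (illustrative,
# not verbatim corpus data):
#
#   <ESSAY title="...">
#     <TEXT>
#       <PASSAGE id="A2-0023-1">...passage text...</PASSAGE>
#     </TEXT>
#     <MISTAKE id="A2-0023-1" location="3">
#       <WRONG>...wrong phrase...</WRONG>
#       <CORRECTION>...corrected phrase...</CORRECTION>
#     </MISTAKE>
#   </ESSAY>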
def read_data(fp):
    """Yield each <ESSAY> block from the SIGHAN training SGML files as a raw string."""
    for fn in os.listdir(fp):
        if fn.endswith('ing.sgml'):
            with open(os.path.join(fp, fn), 'r', encoding='utf-8', errors='ignore') as f:
                item = []
                for line in f:
                    # A new <ESSAY ...> tag starts the next block; flush the previous one.
                    if line.strip().startswith('<ESSAY') and len(item) > 0:
                        yield ''.join(item)
                        item = [line.strip()]
                    elif line.strip().startswith('<'):
                        item.append(line.strip())
                # Flush the last essay in the file.
                if item:
                    yield ''.join(item)

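# Usage sketch: read_data() yields one raw <ESSAY> string per essay, which
# proc_item() below flattens into JSON-ready records, e.g.
#
#   for essay_xml in read_data('../data/cn/sighan_2015/'):  # hypothetical path
#       records = proc_item(essay_xml)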
def proc_item(item):
    """
    Process one essay from the training set.
    Args:
        item: raw XML string of a single <ESSAY> block
    Returns:
        list
    """
    root = etree.XML(item)
    passages = dict()
    mistakes = []
    for passage in root.xpath('/ESSAY/TEXT/PASSAGE'):
        passages[passage.get('id')] = traditional2simplified(passage.text)
    for mistake in root.xpath('/ESSAY/MISTAKE'):
        mistakes.append({'id': mistake.get('id'),
                         'location': int(mistake.get('location')) - 1,
                         'wrong': traditional2simplified(mistake.xpath('./WRONG/text()')[0].strip()),
                         'correction': traditional2simplified(mistake.xpath('./CORRECTION/text()')[0].strip())})

    rst_items = dict()

    def get_passages_by_id(pgs, _id):
        p = pgs.get(_id)
        if p:
            return p
        # Some mistake ids are off by one; retry with the trailing digit incremented.
        _id = _id[:-1] + str(int(_id[-1]) + 1)
        p = pgs.get(_id)
        if p:
            return p
        raise ValueError(f'passage not found by {_id}')

    for mistake in mistakes:
        if mistake['id'] not in rst_items:
            rst_items[mistake['id']] = {'original_text': get_passages_by_id(passages, mistake['id']),
                                        'wrong_ids': [],
                                        'correct_text': get_passages_by_id(passages, mistake['id'])}

        ori_text = rst_items[mistake['id']]['original_text']
        cor_text = rst_items[mistake['id']]['correct_text']
        if len(ori_text) == len(cor_text):
            if ori_text[mistake['location']] in mistake['wrong']:
                rst_items[mistake['id']]['wrong_ids'].append(mistake['location'])
                # Patch the corrected text in place over the span covered by the WRONG phrase.
                wrong_char_idx = mistake['wrong'].index(ori_text[mistake['location']])
                start = mistake['location'] - wrong_char_idx
                end = start + len(mistake['wrong'])
                rst_items[mistake['id']]['correct_text'] = \
                    f'{cor_text[:start]}{mistake["correction"]}{cor_text[end:]}'
        else:
            print(f'error line:\n{mistake["id"]}\n{ori_text}\n{cor_text}')

    rst = []
    for k in rst_items.keys():
        if len(rst_items[k]['correct_text']) == len(rst_items[k]['original_text']):
            rst.append({'id': k, **rst_items[k]})
        else:
            # Length mismatch after patching: fall back to an error-free pair.
            text = rst_items[k]['correct_text']
            rst.append({'id': k, 'correct_text': text, 'original_text': text, 'wrong_ids': []})
    return rst

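# Every record produced by proc_item() and proc_test_set() shares one schema;
# the values here are illustrative, not taken from the corpus:
#
#   {'id': 'A2-0023-1',
#    'original_text': '...sentence with errors...',
#    'wrong_ids': [2, 7],  # 0-based character offsets
#    'correct_text': '...corrected sentence...'}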
def proc_test_set(fp):
    """
    Generate the SIGHAN15 test set.
    Args:
        fp:
    Returns:
        list
    """
    inputs = dict()
    with open(os.path.join(fp, 'SIGHAN15_CSC_TestInput.txt'), 'r', encoding='utf-8') as f:
        for line in f:
            # Fixed-width '(pid=...)' prefix: the 9-char pid sits at [5:14],
            # the passage text starts at offset 16.
            pid = line[5:14]
            text = line[16:].strip()
            inputs[pid] = text

    rst = []
    with open(os.path.join(fp, 'SIGHAN15_CSC_TestTruth.txt'), 'r', encoding='utf-8') as f:
        for line in f:
            pid = line[0:9]
            mistakes = line[11:].strip().split(', ')
            if len(mistakes) <= 1:
                # A single trailing field marks an error-free passage.
                text = traditional2simplified(inputs[pid])
                rst.append({'id': pid,
                            'original_text': text,
                            'wrong_ids': [],
                            'correct_text': text})
            else:
                wrong_ids = []
                original_text = inputs[pid]
                cor_text = inputs[pid]
                # Truth fields come in (1-based position, correct char) pairs.
                for i in range(len(mistakes) // 2):
                    idx = int(mistakes[2 * i]) - 1
                    cor_char = mistakes[2 * i + 1]
                    wrong_ids.append(idx)
                    cor_text = f'{cor_text[:idx]}{cor_char}{cor_text[idx + 1:]}'
                original_text = traditional2simplified(original_text)
                cor_text = traditional2simplified(cor_text)
                if len(original_text) != len(cor_text):
                    print('error line:\n', pid)
                    print(original_text)
                    print(cor_text)
                    continue
                rst.append({'id': pid,
                            'original_text': original_text,
                            'wrong_ids': wrong_ids,
                            'correct_text': cor_text})

    return rst

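# The fixed offsets in proc_test_set() assume these line layouts (a sketch
# inferred from the slicing above, not verbatim data):
#
#   SIGHAN15_CSC_TestInput.txt:  (pid=A2-0023-1)\t<passage text>
#   SIGHAN15_CSC_TestTruth.txt:  A2-0023-1, 3, X, 7, Y
#       i.e. pid followed by (1-based position, correct char) pairs; a lone
#       trailing field marks an error-free passage.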
def parse_cged_file(file_dir):
    rst = []
    seen_pairs = set()
    for fn in os.listdir(file_dir):
        if fn.endswith('.xml'):
            path = os.path.join(file_dir, fn)
            print('Parse data from %s' % path)

            dom_tree = minidom.parse(path)
            docs = dom_tree.documentElement.getElementsByTagName('DOC')
            for doc in docs:
                doc_id = ''
                text = ''
                texts = doc.getElementsByTagName('TEXT')
                for i in texts:
                    doc_id = i.getAttribute('id')
                    # Read the source text
                    text = i.childNodes[0].data.strip()
                # Read the corrected text
                correction = doc.getElementsByTagName('CORRECTION')[0].childNodes[0].data.strip()
                wrong_ids = []
                for error in doc.getElementsByTagName('ERROR'):
                    start_off = error.getAttribute('start_off')
                    end_off = error.getAttribute('end_off')
                    if start_off and end_off:
                        for i in range(int(start_off), int(end_off) + 1):
                            wrong_ids.append(i)
                source = text.strip()
                target = correction.strip()

                # Deduplicate documents that share the same (source, target) pair.
                pair = (source, target)
                if pair not in seen_pairs:
                    seen_pairs.add(pair)
                    rst.append({'id': doc_id,
                                'original_text': source,
                                'wrong_ids': wrong_ids,
                                'correct_text': target})
    save_json(rst, os.path.join(pwd_path, 'output/cged.json'))
    return rst

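# parse_cged_file() expects CGED-style XML; based on the tag and attribute
# names above, a document looks roughly like this sketch (illustrative, not
# verbatim data):
#
#   <DOC>
#     <TEXT id="...">...source sentence...</TEXT>
#     <CORRECTION>...corrected sentence...</CORRECTION>
#     <ERROR start_off="3" end_off="4"/>
#   </DOC>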
def main():
    # Note: the CGED training set is small; use it only for model testing.
    # parse_cged_file(os.path.join(pwd_path, '../data/cn/CGED/'))
    sighan15_dir = os.path.join(pwd_path, '../data/cn/sighan_2015/')
    rst_items = []
    test_lst = proc_test_set(sighan15_dir)
    for item in read_data(sighan15_dir):
        rst_items += proc_item(item)

    # Split into train and dev sets
    print('data_size:', len(rst_items))
    train_lst, dev_lst = train_test_split(rst_items, test_size=0.1, random_state=42)
    save_json(train_lst, os.path.join(pwd_path, 'output/train.json'))
    save_json(dev_lst, os.path.join(pwd_path, 'output/dev.json'))
    save_json(test_lst, os.path.join(pwd_path, 'output/test.json'))


if __name__ == '__main__':
    main()
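# Running this module writes train.json / dev.json / test.json (and cged.json
# when parse_cged_file() is enabled) under output/ next to this file.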