# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description: Preprocess SIGHAN-2015 / CGED corpora into JSON files
(original_text / wrong_ids / correct_text records) for CSC model training.
"""
import os
import sys
from codecs import open

from sklearn.model_selection import train_test_split
from lxml import etree
from xml.dom import minidom

sys.path.append('../..')
from pycorrector.utils.io_utils import save_json
from pycorrector import traditional2simplified

pwd_path = os.path.abspath(os.path.dirname(__file__))


def read_data(fp):
    """Yield one raw `<ESSAY>...</ESSAY>` XML fragment at a time.

    Scans `fp` for SIGHAN training files (`*ing.sgml`) and groups
    consecutive tag lines into essay fragments, splitting on each new
    `<ESSAY` opening tag.

    Args:
        fp: directory containing the ``*ing.sgml`` training files.

    Yields:
        str: the concatenated lines of a single essay.
    """
    for fn in os.listdir(fp):
        if fn.endswith('ing.sgml'):
            with open(os.path.join(fp, fn), 'r', encoding='utf-8', errors='ignore') as f:
                item = []
                for line in f:
                    # A new <ESSAY opening tag closes the previous essay.
                    if line.strip().startswith('<ESSAY') and len(item) > 0:
                        yield ''.join(item)
                        item = [line.strip()]
                    elif line.strip().startswith('<'):
                        item.append(line.strip())
                # Bug fix: the last essay in each file was previously dropped.
                if item:
                    yield ''.join(item)


def proc_item(item):
    """Parse one `<ESSAY>` XML fragment into training records.

    Args:
        item: raw XML string of a single essay.

    Returns:
        list[dict]: records with keys id/original_text/wrong_ids/correct_text.
    """
    root = etree.XML(item)
    passages = dict()
    mistakes = []
    # Collect all passages, keyed by their id, converted to simplified Chinese.
    for passage in root.xpath('/ESSAY/TEXT/PASSAGE'):
        passages[passage.get('id')] = traditional2simplified(passage.text)
    # Collect all annotated mistakes; location is converted to 0-based index.
    for mistake in root.xpath('/ESSAY/MISTAKE'):
        mistakes.append({'id': mistake.get('id'),
                         'location': int(mistake.get('location')) - 1,
                         'wrong': traditional2simplified(mistake.xpath('./WRONG/text()')[0].strip()),
                         'correction': traditional2simplified(mistake.xpath('./CORRECTION/text()')[0].strip())})

    rst_items = dict()

    def get_passages_by_id(pgs, _id):
        """Look up a passage; fall back to the next sequential id (some
        annotations reference an off-by-one passage id in the corpus)."""
        p = pgs.get(_id)
        if p:
            return p
        _id = _id[:-1] + str(int(_id[-1]) + 1)
        p = pgs.get(_id)
        if p:
            return p
        raise ValueError(f'passage not found by {_id}')

    for mistake in mistakes:
        if mistake['id'] not in rst_items.keys():
            rst_items[mistake['id']] = {'original_text': get_passages_by_id(passages, mistake['id']),
                                        'wrong_ids': [],
                                        'correct_text': get_passages_by_id(passages, mistake['id'])}

        ori_text = rst_items[mistake['id']]['original_text']
        cor_text = rst_items[mistake['id']]['correct_text']
        if len(ori_text) == len(cor_text):
            if ori_text[mistake['location']] in mistake['wrong']:
                rst_items[mistake['id']]['wrong_ids'].append(mistake['location'])
                # Align the wrong span inside the passage, then splice in
                # the correction so original/correct stay the same length.
                wrong_char_idx = mistake['wrong'].index(ori_text[mistake['location']])
                start = mistake['location'] - wrong_char_idx
                end = start + len(mistake['wrong'])
                rst_items[mistake['id']][
                    'correct_text'] = f'{cor_text[:start]}{mistake["correction"]}{cor_text[end:]}'
        else:
            print(f'error line:\n{mistake["id"]}\n{ori_text}\n{cor_text}')

    rst = []
    for k in rst_items.keys():
        if len(rst_items[k]['correct_text']) == len(rst_items[k]['original_text']):
            rst.append({'id': k, **rst_items[k]})
        else:
            # Length mismatch after splicing: keep the corrected text as a
            # clean (no-error) sample rather than an inconsistent pair.
            text = rst_items[k]['correct_text']
            rst.append({'id': k, 'correct_text': text, 'original_text': text, 'wrong_ids': []})
    return rst


def proc_test_set(fp):
    """Build the SIGHAN-2015 test set from the official input/truth files.

    Args:
        fp: directory containing ``SIGHAN15_CSC_TestInput.txt`` and
            ``SIGHAN15_CSC_TestTruth.txt``.

    Returns:
        list[dict]: records with keys id/original_text/wrong_ids/correct_text.
    """
    inputs = dict()
    # Fixed-width line format: pid at columns 5:14, text from column 16 on.
    with open(os.path.join(fp, 'SIGHAN15_CSC_TestInput.txt'), 'r', encoding='utf-8') as f:
        for line in f:
            pid = line[5:14]
            text = line[16:].strip()
            inputs[pid] = text

    rst = []
    with open(os.path.join(fp, 'SIGHAN15_CSC_TestTruth.txt'), 'r', encoding='utf-8') as f:
        for line in f:
            pid = line[0:9]
            mistakes = line[11:].strip().split(', ')
            if len(mistakes) <= 1:
                # Truth line with a single field ("0") means no error.
                text = traditional2simplified(inputs[pid])
                rst.append({'id': pid,
                            'original_text': text,
                            'wrong_ids': [],
                            'correct_text': text})
            else:
                # Truth line alternates: location, correction, location, ...
                wrong_ids = []
                original_text = inputs[pid]
                cor_text = inputs[pid]
                for i in range(len(mistakes) // 2):
                    idx = int(mistakes[2 * i]) - 1  # 1-based -> 0-based
                    cor_char = mistakes[2 * i + 1]
                    wrong_ids.append(idx)
                    cor_text = f'{cor_text[:idx]}{cor_char}{cor_text[idx + 1:]}'
                original_text = traditional2simplified(original_text)
                cor_text = traditional2simplified(cor_text)
                if len(original_text) != len(cor_text):
                    print('error line:\n', pid)
                    print(original_text)
                    print(cor_text)
                    continue
                rst.append({'id': pid,
                            'original_text': original_text,
                            'wrong_ids': wrong_ids,
                            'correct_text': cor_text})
    return rst


def parse_cged_file(file_dir):
    """Parse CGED XML files into training records and save them as JSON.

    Args:
        file_dir: directory containing the CGED ``*.xml`` files.

    Returns:
        list[dict]: records with keys id/original_text/wrong_ids/correct_text.
    """
    rst = []
    seen_pairs = set()  # dedup on (source, target); the old `pair in rst` check never matched
    for fn in os.listdir(file_dir):
        if fn.endswith('.xml'):
            path = os.path.join(file_dir, fn)
            print('Parse data from %s' % path)
            dom_tree = minidom.parse(path)
            docs = dom_tree.documentElement.getElementsByTagName('DOC')
            for doc in docs:
                doc_id = ''
                text = ''
                texts = doc.getElementsByTagName('TEXT')
                for i in texts:
                    doc_id = i.getAttribute('id')
                    # Input the text
                    text = i.childNodes[0].data.strip()
                # Input the correct text
                correction = doc.getElementsByTagName('CORRECTION')[0]. \
                    childNodes[0].data.strip()
                wrong_ids = []
                for error in doc.getElementsByTagName('ERROR'):
                    start_off = error.getAttribute('start_off')
                    end_off = error.getAttribute('end_off')
                    if start_off and end_off:
                        # Offsets are inclusive on both ends.
                        for i in range(int(start_off), int(end_off) + 1):
                            wrong_ids.append(i)
                source = text.strip()
                target = correction.strip()
                pair = (source, target)
                if pair not in seen_pairs:
                    seen_pairs.add(pair)
                    rst.append({'id': doc_id,
                                'original_text': source,
                                'wrong_ids': wrong_ids,
                                'correct_text': target
                                })
    save_json(rst, os.path.join(pwd_path, 'output/cged.json'))
    return rst


def main():
    # NOTE: the training sample is small; intended for model testing only.
    # parse_cged_file(os.path.join(pwd_path, '../data/cn/CGED/'))
    sighan15_dir = os.path.join(pwd_path, '../data/cn/sighan_2015/')
    rst_items = []
    test_lst = proc_test_set(sighan15_dir)
    for item in read_data(sighan15_dir):
        rst_items += proc_item(item)

    # Split into train/dev sets (fixed seed for reproducibility).
    print('data_size:', len(rst_items))
    train_lst, dev_lst = train_test_split(rst_items, test_size=0.1, random_state=42)
    os.makedirs(os.path.join(pwd_path, 'output'), exist_ok=True)
    save_json(train_lst, os.path.join(pwd_path, 'output/train.json'))
    save_json(dev_lst, os.path.join(pwd_path, 'output/dev.json'))
    save_json(test_lst, os.path.join(pwd_path, 'output/test.json'))


if __name__ == '__main__':
    main()