New MCPDict format by lernanto · Pull Request #40 · lernanto/sincomp · GitHub

New MCPDict format #40

Merged · 5 commits · Feb 24, 2025
12 changes: 8 additions & 4 deletions scripts/model.py
@@ -169,16 +169,20 @@ def prepare(args: argparse.Namespace, config: dict) -> None:
# Build vocabularies from the dialect data
data = load_data(config['datasets'])

minfreq = config.get('min_freq')
minfreq = config.get('min_frequency')
maxcat = config.get('max_categories')

vocabs = {}
for name, columns in config['columns'].items():
voc = []
for c in columns:
dic = sincomp.auxiliary.make_dict(
data[c],
# When minfreq is a dict, a different minimum frequency can be set for each column
minfreq=minfreq.get(name) if isinstance(minfreq, dict) else minfreq,
sort='value'
# When minfreq or maxcat is a dict, a different value can be set for each column
min_frequency=minfreq.get(name) if isinstance(minfreq, dict) \
else minfreq,
max_categories=maxcat.get(name) if isinstance(maxcat, dict) \
else maxcat,
)
logging.info(f'{c} vocabulary size = {dic.shape[0]}')
voc.append(dic.index.tolist())
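
For illustration, a hedged sketch of a config fragment exercising the new keys; the key names mirror the code above, while the column names and values are made up:

    # min_frequency / max_categories may each be a single value applied to
    # every vocabulary, or a dict keyed by the names in config['columns']
    config = {
        'min_frequency': {'initial': 2, 'final': 0.001},  # absolute count or proportion
        'max_categories': 200,                            # one global cap
        'columns': {'initial': ['initial'], 'final': ['final']},
    }
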
176 changes: 121 additions & 55 deletions src/sincomp/align.py
@@ -14,6 +14,7 @@
"""


import argparse
import itertools
import logging
import numpy
@@ -28,11 +29,17 @@
import sklearn.feature_extraction.text
import sklearn.linear_model
import sklearn.metrics
import sklearn.model_selection
import sklearn.preprocessing
import sklearn.pipeline
import typing

from .preprocess import transform
from . import datasets, preprocess


logger = logging.getLogger(__name__)
if not logger.hasHandlers():
logger.addHandler(logging.StreamHandler())


def prepare(
@@ -59,7 +66,7 @@ def prepare(
traditional: traditional forms of the character table contained in transformed; empty if the original table provides none
"""

transformed = transform(
transformed = preprocess.transform(
dataset.fillna({'initial': '', 'final': '', 'tone': ''}),
index='cid',
columns='did',
@@ -528,7 +535,7 @@ def align_no_cid(
),
sklearn.preprocessing.Normalizer('l2')
), i) for i in range(base.shape[1])]).fit_transform(base.fillna(''))
logging.debug(f'fit base dialect embeddings with {matrix[mask].shape} data.')
logger.debug(f'fit base dialect embeddings with {matrix[mask].shape} data.')
emb = sklearn.decomposition.TruncatedSVD(emb_size).fit(matrix[mask]) \
.transform(matrix)

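For illustration, a self-contained sketch of this embedding scheme: one character-level TF-IDF per dialect column, L2-normalized, horizontally concatenated, then reduced by truncated SVD. The toy table, the column names, and the use of make_column_transformer are assumptions for the example, not the repo's exact pipeline:

    import pandas
    import sklearn.compose
    import sklearn.decomposition
    import sklearn.feature_extraction.text
    import sklearn.pipeline
    import sklearn.preprocessing

    # 3 characters, pronunciations in 2 dialects (made-up strings)
    base = pandas.DataFrame({
        'd1': ['pa1', 'pha2', 'pa1'],
        'd2': ['ba5', 'pa2', 'ba5'],
    })
    matrix = sklearn.compose.make_column_transformer(*[
        (sklearn.pipeline.make_pipeline(
            sklearn.feature_extraction.text.TfidfVectorizer(analyzer='char'),
            sklearn.preprocessing.Normalizer('l2'),
        ), i) for i in range(base.shape[1])
    ]).fit_transform(base.fillna(''))
    emb = sklearn.decomposition.TruncatedSVD(2).fit_transform(matrix)
    print(emb.shape)    # (3, 2): one dense vector per character
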
@@ -573,49 +580,8 @@ def align_no_cid(
return result


if __name__ == '__main__':
import argparse

from . import datasets, preprocess


parser = argparse.ArgumentParser('Align the specified datasets into a new aggregated dataset')
parser.add_argument(
'-e',
'--embedding-size',
type=int,
default=10,
help='length of the character embeddings used to align polyphones'
)
parser.add_argument(
'--prefix',
default='aligned',
help='output path prefix for the aligned datasets'
)
parser.add_argument(
'--charmap-output',
default='charmap.csv',
help='output file for the old-to-new character ID mapping table'
)
parser.add_argument(
'--char-output',
default=os.path.join('aligned', '.characters'),
help='output file mapping the aligned new character IDs to the original IDs in each dataset'
)
parser.add_argument(
'--dialect-output',
default=os.path.join('aligned', '.dialects'),
help='merged dialect information file for all datasets'
)
parser.add_argument(
'datasets',
nargs='*',
default=('CCR', 'MCPDict'),
help='list of datasets to align'
)
args = parser.parse_args()

logging.getLogger().setLevel(logging.INFO)
def main(args: argparse.Namespace) -> None:
"""对齐指定数据集的多音字"""

names = []
dialects = []
@@ -635,7 +601,7 @@ def align_no_cid(
else:
nocid.append(data)

logging.info(
logger.info(
f'align datasets {", ".join(names)}, '
f'embedding size = {args.embedding_size}...'
)
@@ -646,7 +612,7 @@

# Align the polyphones, producing a mapping from old character IDs to new ones
names = [d.name for d, _ in withcid]
logging.info(f'align datasets {", ".join(names)}...')
logger.info(f'align datasets {", ".join(names)}...')
char_lists = align(*withcid, emb_size=args.embedding_size)

# Merge into one overall old-to-new character ID mapping table
@@ -664,7 +630,10 @@
.groupby('label')[['simplified', 'traditional']].first()
chars.rename_axis('cid', inplace=True)

logging.info(f'annotate datasets without character ID {", ".join([d.name for d in nocid])}...')
logger.info(
f'annotate datasets without character ID: '
f'{", ".join([d.name for d in nocid])} ...'
)
base = pandas.concat(
[preprocess.transform(
d.dropna(subset=['cid', 'did']).replace({'cid': c['label']}),
@@ -714,17 +683,22 @@
)

path = os.path.abspath(args.charmap_output)
logging.info(f'save {charmap.shape[0]} character mapping to {path}...')
logger.info(f'save {charmap.shape[0]} character mapping to {path}...')
os.makedirs(os.path.dirname(path), exist_ok=True)
charmap.to_csv(path, encoding='utf-8', lineterminator='\n')

path = os.path.abspath(args.char_output)
logging.info(f'save {chars.shape[0]} character information to {path}...')
# Prefer the traditional character form; fall back to the simplified form where none exists
chars['character'] = chars['traditional'].where(
chars['traditional'].notna(),
chars['simplified']
)
path = os.path.abspath(args.character_output)
logger.info(f'save {chars.shape[0]} character information to {path}...')
os.makedirs(os.path.dirname(path), exist_ok=True)
chars.to_csv(path, encoding='utf-8', lineterminator='\n')

path = os.path.abspath(args.dialect_output)
logging.info(f'save {dialects.shape[0]} dialect information to {path}...')
logger.info(f'save {dialects.shape[0]} dialect information to {path}...')
os.makedirs(os.path.dirname(path), exist_ok=True)
dialects.to_csv(path, encoding='utf-8', lineterminator='\n')

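A tiny illustration of the Series.where fallback used a few lines above for the character column; the data is made up:

    import pandas

    trad = pandas.Series(['說', None, '話'])
    simp = pandas.Series(['说', '的', '话'])
    # keep the traditional form where present, otherwise the simplified one
    print(trad.where(trad.notna(), simp).tolist())   # ['說', '的', '話']
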
@@ -753,7 +727,7 @@ def align_no_cid(
dirname = os.path.abspath(os.path.join(args.prefix, dataset.name))
path = os.path.join(dirname, str(did))
os.makedirs(dirname, exist_ok=True)
logging.info(f'save {data.shape[0]} aligned data to {path}...')
logger.info(f'save {data.shape[0]} aligned data to {path}...')
data.to_csv(path, index=False, encoding='utf-8', lineterminator='\n')

# Add character IDs to the datasets that lack them, then save
@@ -776,5 +750,97 @@ def align_no_cid(
dirname = os.path.abspath(os.path.join(args.prefix, dataset.name))
path = os.path.join(dirname, str(did))
os.makedirs(dirname, exist_ok=True)
logging.info(f'save {data.shape[0]} aligned data to {path}...')
logger.info(f'save {data.shape[0]} aligned data to {path}...')
data.to_csv(path, index=False, encoding='utf-8', lineterminator='\n')

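With the defaults above, the outputs of main() would land in a layout roughly like this (dataset names and dialect IDs are illustrative):

    charmap.csv            # old -> new character ID mapping (--charmap-output)
    aligned/
        .characters        # aligned character information (--character-output)
        .dialects          # merged dialect information (--dialect-output)
        CCR/
            <did>          # one CSV of aligned records per dialect ID
        MCPDict/
            <did>
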
def evaluate(args: argparse.Namespace) -> None:
"""
Evaluate polyphone alignment accuracy using a single dataset

Randomly split the dataset into two halves by dialect site, treat the character correspondences between the halves as unknown, apply the alignment
algorithm, and compute its accuracy against the ground truth.
"""

if (dataset := datasets.get(args.datasets[0])) is None:
return

logger.info(f'evaluating alignment accuracy of polyphones for {dataset.name} ...')

# Randomly split the dataset into two halves by dialect site
dids1, dids2 = sklearn.model_selection.train_test_split(
dataset.dialect_ids,
test_size=0.5
)
data1, data2 = dataset.filter(dids1), dataset.filter(dids2)

chars = dataset.characters['character']
chars1, chars2 = align((data1, chars), (data2, chars))

# Compute the alignment accuracy
cids = chars[chars.duplicated(False)].index \
.intersection(chars1.index) \
.intersection(chars2.index)
labels1 = chars1.loc[cids, 'label']
labels2 = chars2.loc[cids, 'label']
acc = labels1 == labels2

if not acc.all():
for i, r in dataset.characters.loc[cids[~acc.values]] \
.assign(label1=labels1, label2=labels2) \
.sort_values('character') \
.iterrows():
logger.info(
f'bad case: {i}: {r["character"]}, '
f'label1 = {r["label1"]}, label2 = {r["label2"]}'
)

acc = acc.astype(int)
print(f'{dataset.name}: accuracy = {acc.mean()} ({acc.sum()}/{acc.count()})')

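A small sketch of how the scoring above restricts itself to polyphones via duplicated(False); the IDs and characters are made up:

    import pandas

    # duplicated(False) marks every occurrence of a repeated character,
    # i.e. exactly the characters with more than one entry
    chars = pandas.Series(['行', '行', '好'], index=[101, 102, 103], name='character')
    cids = chars[chars.duplicated(False)].index
    print(cids.tolist())    # [101, 102]
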

if __name__ == '__main__':
parser = argparse.ArgumentParser('Align the specified datasets into a new aggregated dataset')
parser.add_argument('-l', '--log-level', default='WARNING', help='logging level')
parser.add_argument(
'-e',
'--evaluate',
default=False,
action='store_true',
help='evaluate polyphone alignment accuracy using a single dataset'
)
parser.add_argument(
'-n',
'--embedding-size',
type=int,
default=10,
help='length of the character embeddings used to align polyphones'
)
parser.add_argument(
'--prefix',
default='aligned',
help='output path prefix for the aligned datasets'
)
parser.add_argument(
'--charmap-output',
default='charmap.csv',
help='output file for the old-to-new character ID mapping table'
)
parser.add_argument(
'--character-output',
default=os.path.join('aligned', '.characters'),
help='output file mapping the aligned new character IDs to the original IDs in each dataset'
)
parser.add_argument(
'--dialect-output',
default=os.path.join('aligned', '.dialects'),
help='merged dialect information file for all datasets'
)
parser.add_argument('datasets', nargs='+', help='list of datasets to align')
args = parser.parse_args()

logger.setLevel(getattr(logging, args.log_level.upper()))

if args.evaluate:
evaluate(args)
else:
main(args)
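
Hypothetical invocations of the two modes, assuming the package is installed and these dataset names are configured:

    python -m sincomp.align -l INFO --prefix aligned CCR MCPDict
    python -m sincomp.align -e MCPDict
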
31 changes: 20 additions & 11 deletions src/sincomp/auxiliary.py
@@ -16,31 +16,40 @@
import sklearn.feature_extraction.text


def make_dict(data, minfreq=None, sort=None):
def make_dict(
data: numpy.ndarray[str],
min_frequency: int | float | None = None,
max_categories: int | None = None,
sort: str | None = None
) -> pandas.Series:
"""
Build a vocabulary from dialect data.

Parameters:
data (array-like): one column of pronunciation samples; an empty string denotes a missing value
minfreq (float or int): a symbol is included only if it occurs at least this many times
sort (str): ordering of the returned vocabulary:
data: one column of pronunciation samples
min_frequency: a symbol is included only if it occurs at least this many times
max_categories: keep only this many of the most frequent categories
sort: ordering of the returned vocabulary:
- value: lexicographic order of the symbols
- frequency: by frequency, descending

Returns:
dic (`pandas.Series`): the vocabulary; indexed by the symbols occurring in data, values are their frequencies
dic: the vocabulary; indexed by the symbols occurring in data, values are their frequencies
"""

data = pandas.Series(data)
dic = data[data != ''].value_counts().rename('frequency')

if minfreq is not None:
# If minfreq is a float, it specifies the minimum proportion of occurrences
if isinstance(minfreq, float):
minfreq = int(minfreq * len(data))
if min_frequency is not None:
# If min_frequency is a float, it specifies the minimum proportion of occurrences
if isinstance(min_frequency, float):
min_frequency = int(min_frequency * len(data))

if minfreq > 1:
dic = dic[dic >= minfreq]
if min_frequency > 1:
dic = dic[dic >= min_frequency]

if max_categories is not None:
dic = dic.sort_values(ascending=False)[:max_categories]

# Sort the vocabulary as requested
if sort == 'value':
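
A hypothetical call to the updated make_dict, with made-up sample data:

    import numpy
    import sincomp.auxiliary

    data = numpy.array(['p', 'p', 'ph', '', 'b'])
    # empty strings are dropped; 'ph' and 'b' fall below min_frequency
    dic = sincomp.auxiliary.make_dict(data, min_frequency=2, sort='value')
    print(dic)    # only 'p' survives, with frequency 2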