Commit 3d19bc8f authored by ctt

Modify file paths

Parent 36594c06
@@ -127,7 +127,7 @@ def summary(text, summary_length):
 # zh_nlp = stanza.Pipeline('zh-hans')
 # en_nlp = stanza.Pipeline('en')
 # nlp_dict = {'zh': zh_nlp, 'en': en_nlp}
-#model = KeyedVectors.load_word2vec_format(os.path.join(BASE_DIR, 'static/platform_base/Tencent_AILab_ChineseEmbedding.bin'), binary=True)
+#model = KeyedVectors.load_word2vec_format(os.path.join(BASE_DIR, 'static/base/Tencent_AILab_ChineseEmbedding.bin'), binary=True)
 # if __name__ == '__main__':
 # print(word_cut('汤姆生病了。他去了医院。'))
 # print(word_pos('汤姆生病了。他去了医院。'))
...
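Review note: the only change in this hunk is the `static/platform_base` → `static/base` path inside a line that is commented out anyway. If that load were ever re-enabled, a minimal sketch looks like the following; `BASE_DIR` is stubbed here (the repo defines its own), and the optional `limit=` argument is an assumption used so a smoke test does not pull the full multi-gigabyte Tencent embedding into memory.

```python
import os
from gensim.models import KeyedVectors

BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # stand-in for the repo's BASE_DIR
path = os.path.join(BASE_DIR, 'static/base/Tencent_AILab_ChineseEmbedding.bin')

# limit= truncates the vocabulary read from the file, keeping a quick
# sanity check cheap; drop it to load every vector.
model = KeyedVectors.load_word2vec_format(path, binary=True, limit=50000)
print(model.vector_size)
```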
@@ -124,7 +124,7 @@ class BaseDataProcess:
         count = 0
         if self.embedding_config['use_Tencent']:
             model_tencent = gensim.models.KeyedVectors.load_word2vec_format(
-                os.path.join(BASE_DIR, 'static/platform_base/Tencent_AILab_ChineseEmbedding.bin'), binary=True)
+                os.path.join(BASE_DIR, 'static/base/Tencent_AILab_ChineseEmbedding.bin'), binary=True)
             vocabulary_tencent = model_tencent.wv.vocab.keys()
             vector_matrix = np.zeros((len(feature_words), int(self.embedding_config['size']) + 200))
         for word in feature_words:
...
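One thing worth flagging while this file is being touched: `model_tencent.wv.vocab` is the gensim 3.x spelling (`.wv` on a `KeyedVectors` returns the object itself there, with a deprecation warning), and gensim 4.0 removed both `.wv` and `.vocab` in favour of `key_to_index`. A version-tolerant helper, assuming the same loaded vectors, could read:

```python
def tencent_vocab(kv):
    """Vocabulary keys for a KeyedVectors object across gensim major versions."""
    try:
        return kv.key_to_index.keys()  # gensim >= 4.0
    except AttributeError:
        return kv.wv.vocab.keys()      # gensim 3.x, the API used in this diff
```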
@@ -27,7 +27,7 @@ class DataArguments:

 @dataclass
 class ModelArguments:
-    model_name_or_path: str = field(default="ernie-3.0-platform_base-zh", metadata={"help": "Build-in pretrained model name or the path to local model."})
+    model_name_or_path: str = field(default="ernie-3.0-base-zh", metadata={"help": "Build-in pretrained model name or the path to local model."})
     export_type: str = field(default='paddle', metadata={"help": "The type to export. Support `paddle` and `onnx`."})
...
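The old default `ernie-3.0-platform_base-zh` looks like collateral damage from a project-wide base → platform_base rename; `ernie-3.0-base-zh` is the real built-in model name, and this hunk restores it. As a standard-library-only sanity check of how these dataclass fields behave (class body copied from the hunk, the printing loop is illustrative):

```python
from dataclasses import dataclass, field, fields

@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="ernie-3.0-base-zh", metadata={"help": "Build-in pretrained model name or the path to local model."})
    export_type: str = field(default='paddle', metadata={"help": "The type to export. Support `paddle` and `onnx`."})

# Defaults and help strings are introspectable; argument-parser wrappers
# build their command-line flags from exactly this metadata.
for f in fields(ModelArguments):
    print(f.name, '=', f.default, '|', f.metadata['help'])
```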
@@ -47,7 +47,7 @@ class FewMultiRunner(BaseRunner.BaseRunner):
         self.config_path = config_path
         self.config = FewMultiConfig(self.config_path)

-    def train(self, logger):
+    def train(self, logger2):
         py_path = os.path.abspath(__file__)
         sys.argv = [py_path]
         print(self.config)
...
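The `logger` → `logger2` rename only changes the parameter name, perhaps to avoid shadowing a module-level `logger`; positional callers are unaffected. The body that follows resets `sys.argv` so a downstream argparse-style trainer starts from a clean command line; a minimal sketch of that pattern, with the appended flag purely hypothetical:

```python
import os
import sys

# Wipe the caller's flags: argv keeps only the script path, so a parser
# invoked further down sees defaults plus whatever we append explicitly.
py_path = os.path.abspath(__file__)
sys.argv = [py_path]
sys.argv += ['--model_name_or_path', 'ernie-3.0-base-zh']  # hypothetical flag
```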
@@ -46,7 +46,7 @@ def merge_para(paras):
     return new_paras

 def filter_stopwords(para):
-    path = os.path.join(BASE_DIR, 'static/platform_base/baidu_stopwords.txt')
+    path = os.path.join(BASE_DIR, 'static/base/baidu_stopwords.txt')
     stopword_list = [k.strip() for k in read_txt(path) if
                      k.strip() != '']
     words = [word for word in jieba.lcut(para) if word not in stopword_list]
...
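For reference, a self-contained version of `filter_stopwords` as it stands after this commit; `read_txt` (a repo helper not shown in the diff) is replaced here by a plain file read, and `baidu_stopwords.txt` is assumed to hold one stopword per line:

```python
import jieba

def filter_stopwords(para, stopwords_path):
    # plain read in place of the repo's read_txt helper
    with open(stopwords_path, encoding='utf-8') as f:
        stopword_list = [k.strip() for k in f if k.strip() != '']
    return [word for word in jieba.lcut(para) if word not in stopword_list]

# e.g. filter_stopwords('汤姆生病了。他去了医院。', 'static/base/baidu_stopwords.txt')
```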
@@ -88,7 +88,7 @@ def update_config_file(config_path, config_file):
     data['data_loader']['dataset_path'] = xlsx_path
     if 'save_fname' in data['runner'].keys():
         data['runner']['save_fpath'] = os.path.join(config_path, data['runner']['save_fname'])
-    data['data_loader']['stopwords_path'] = os.path.join(BASE_DIR, 'static/platform_base/baidu_stopwords.txt')
+    data['data_loader']['stopwords_path'] = os.path.join(BASE_DIR, 'static/base/baidu_stopwords.txt')
     file_path = os.path.join(config_path, 'config.yaml')
     with open(file_path, 'w') as yaml_file:
...
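The hunk cuts off just inside the `with open(...)` block, so the actual dump call is not visible here. A hedged sketch of the write-back step, assuming PyYAML; `allow_unicode=True` is an assumption worth making explicit, since without it non-ASCII values are escaped in the emitted `config.yaml`:

```python
import os
import yaml

def write_config(config_path, data):
    file_path = os.path.join(config_path, 'config.yaml')
    with open(file_path, 'w', encoding='utf-8') as yaml_file:
        # safe_dump restricts output to plain data types, which a config
        # file should contain anyway; allow_unicode keeps Chinese readable.
        yaml.safe_dump(data, yaml_file, allow_unicode=True)
```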
@@ -241,7 +241,7 @@ class WordCount:
         seg_list_exact = jieba.posseg.cut(string_data, HMM=True)  # 精确模式分词+HMM
         object_list = []
         # 去除停用词
-        stopwords_path = os.path.join(BASE_DIR, 'static/platform_base/baidu_stopwords.txt')
+        stopwords_path = os.path.join(BASE_DIR, 'static/base/baidu_stopwords.txt')
         with open(stopwords_path, 'r', encoding='UTF-8') as meaninglessFile:
             stopwords = set(meaninglessFile.read().split('\n'))
             stopwords.add(' ')
...
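The two Chinese comments read "exact-mode segmentation + HMM" and "remove stopwords". A runnable miniature of this counting path, with a tiny inline stopword set standing in for `static/base/baidu_stopwords.txt`:

```python
from collections import Counter
import jieba.posseg

string_data = '汤姆生病了。他去了医院。'
stopwords = {'了', '。', ' ', '\n'}  # stand-in for baidu_stopwords.txt

# Exact-mode segmentation with HMM for unseen words, mirroring
# jieba.posseg.cut(string_data, HMM=True) in the hunk above.
object_list = [pair.word for pair in jieba.posseg.cut(string_data, HMM=True)
               if pair.word not in stopwords]
print(Counter(object_list).most_common())
```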