Commit 0e893d10  Author: ctt

Adjust the code architecture and fix bugs

Parent c4e5365c
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/5/10 14:34
# @Author : 程婷婷
# @FileName: BaseConfig.py
# @Software: PyCharm
import yaml
class BaseConfig:
def __init__(self, config_path):
self._config_path = config_path
self._parsed_file = self.load_config()
def load_config(self):
print(self._config_path)
with open(self._config_path) as yaml_file:
parsed_file = yaml.load(yaml_file, Loader=yaml.FullLoader)
return parsed_file
# if __name__ == '__main__':
# bc = BaseConfig()
# print(bc._parsed_file)
# print(bc.load_config()['data_path'])
# print(bc.load_config()['embedding'])
# print(bc.load_config()['model'])
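For orientation, below is a minimal sketch of the kind of config.yaml that BaseConfig and the classes in this commit expect. The section names (data_loader, data_process, embedding, model, evaluate, runner) mirror the keys read in the code; every concrete value is an illustrative assumption, not the project's shipped configuration.
import yaml

# Hypothetical config skeleton; keys mirror those read by BaseDataLoader,
# BaseDataProcess, BaseModel, BaseEvaluator and BaseRunner below.
EXAMPLE_CONFIG = '''
data_loader:
  dataset_path: ./data/train.xlsx
  stopwords_path: ./static/base/baidu_stopwords.txt
data_process:
  tokenizer: jieba
  use_stopwords: true
  train_size: 0.8
  test_size: 0.1
  random_state: 2021
embedding:
  embedding_path: ./output/embedding
  norm: l2
  use_idf: true
  smooth_idf: true
  with_feature_selection: false
  use_Tencent: false
  size: 100
  window: 5
  min_count: 2
  workers: 4
  sg: 1
  iter: 5
model:
  model_path: ./output/model.pkl
evaluate:
  average: macro
runner:
  save_fname: result.xlsx
'''

# yaml.load accepts a string as well as a file object, so this parses the
# snippet the same way BaseConfig.load_config parses the file on disk.
parsed = yaml.load(EXAMPLE_CONFIG, Loader=yaml.FullLoader)
print(parsed['data_process']['tokenizer'])  # -> jieba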
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/5/11 17:04
# @Author : 程婷婷
# @FileName: __init__.py.py
# @Software: PyCharm
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/6/1 9:58
# @Author : 程婷婷
# @FileName: BaseDataLoader.py
# @Software: PyCharm
import pandas as pd
from base.views.config.BaseConfig import BaseConfig
class BaseDataLoader:
def __init__(self, config_path):
self.data_loader_config = BaseConfig(config_path)._parsed_file['data_loader']
def read_file(self):
symbol = self.data_loader_config['dataset_path'].split('.')[-1]
if (symbol == 'xlsx') or (symbol == 'xls'):
df = pd.read_excel(r''+self.data_loader_config['dataset_path'])
elif symbol == 'csv':
df = pd.read_csv(r''+self.data_loader_config['dataset_path'], sep='\t')
else:
print('数据类型错误')
return '数据类型错误'
df.drop_duplicates(subset='content', keep='first', inplace=True)
df.dropna(subset=['content', 'label'], inplace=True)
df = df.reset_index(drop=True)
print('=================执行正文去重和去空之后共有%d条数据=============' % len(df['content']))
return df
def read_stopwords(self):
# Read the stopword list
stopword_list = [k.strip() for k in open(self.data_loader_config['stopwords_path'], encoding='utf8').readlines() if
k.strip() != '']
return stopword_list
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/5/10 15:28
# @Author : 程婷婷
# @FileName: BaseDataProcess.py
# @Software: PyCharm
import os
import re
import jieba
import pickle
import gensim
import logging
import numpy as np
import pandas as pd
from pyhanlp import *
from bs4 import BeautifulSoup
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.model_selection import train_test_split
from sklearn.feature_selection import mutual_info_classif, SelectPercentile
from base.views.config.BaseConfig import BaseConfig
from base.views.data.BaseDataLoader import BaseDataLoader
from platform_zzsn.settings import BASE_DIR
format = '%(asctime)s %(levelname)s %(pathname)s %(funcName)s %(message)s'
logging.basicConfig(format=format, level=logging.INFO)
class BaseDataProcess:
def __init__(self, config_path):
self.embedding_config = BaseConfig(config_path)._parsed_file['embedding']
self.process_config = BaseConfig(config_path)._parsed_file['data_process']
PerceptronLexicalAnalyzer = JClass('com.hankcs.hanlp.model.perceptron.PerceptronLexicalAnalyzer')
self.pla_segment = PerceptronLexicalAnalyzer()
self.bdl = BaseDataLoader(config_path)
def clean_content(self, content):
bs = BeautifulSoup(content, 'html.parser')
return bs.text
def remove_char(self, content):
# Keep Chinese characters, English letters, digits and basic punctuation
graph_filter = re.compile(r'[^\u4e00-\u9fa5a-zA-Z0-9\s,。\.,?\?!!;;]')
content = graph_filter.sub('', content)
return content
def jieba_tokenizer(self, content):
if self.process_config['use_stopwords']:
stopwords = self.bdl.read_stopwords()
else:
stopwords = []
return ' '.join([word for word in jieba.lcut(content) if word not in stopwords])
def pla_tokenizer(self, content):
words = list(self.pla_segment.analyze(content).toWordArray())
if self.process_config['use_stopwords']:
stopwords = self.bdl.read_stopwords()
else:
stopwords = []
return ' '.join([word for word in words if word not in stopwords])
def save(self, voc, path):
with open(path, 'wb') as voc_file:
pickle.dump(voc, voc_file)
def process(self, data, min_content=0):
processed_data = []
for record in data:
record = self.clean_content(str(record))
record = self.remove_char(record)
if len(record) > min_content:
methods = self.process_config['tokenizer']
if methods == 'PerceptronLexicalAnalyzer':
record = self.pla_tokenizer(record)
record = [word for word in record.split() if word.strip() != '']
else:
record = self.jieba_tokenizer(record)
record = [word for word in record.split() if word.strip() != '']
processed_data.append(' '.join(record))
else:
pass
return processed_data
def split_dataset(self, data, use_dev):
if use_dev:
train_data_set, test_dev_set = train_test_split(data,
train_size=self.process_config['train_size'],
random_state=self.process_config['random_state'],
shuffle=True)
test_data_set, dev_data_set = train_test_split(test_dev_set,
test_size=self.process_config['test_size'],
random_state=self.process_config['random_state'],
shuffle=True)
print(len(train_data_set) + len(test_data_set) + len(dev_data_set))
return train_data_set, test_data_set, dev_data_set
else:
train_data_set, test_data_set = train_test_split(data,
train_size=self.process_config['train_size'],
random_state=self.process_config['random_state'],
shuffle=True)
return train_data_set, test_data_set
def bag_of_words(self, data, label):
vectorizer = CountVectorizer(ngram_range=(1, 1), min_df=5)
x = vectorizer.fit_transform(data)
transformer = TfidfTransformer(norm=self.embedding_config['norm'], use_idf=self.embedding_config['use_idf'],
smooth_idf=self.embedding_config['smooth_idf'])
x = transformer.fit_transform(x).toarray()
if self.embedding_config['with_feature_selection']:
transformed_data = SelectPercentile(mutual_info_classif, percentile=20).fit_transform(x, label)
else:
transformed_data = x
os.makedirs(self.embedding_config['embedding_path'], exist_ok=True)
self.save(voc=vectorizer.vocabulary_, path=os.path.join(self.embedding_config['embedding_path'], 'tfidf.pkl'))
return transformed_data, vectorizer.get_feature_names()
def word2vec(self, data, feature_words):
model = gensim.models.word2vec.Word2Vec(sentences=data,
size=self.embedding_config['size'],
window=self.embedding_config['window'],
min_count=self.embedding_config['min_count'],
workers=self.embedding_config['workers'],
sg=self.embedding_config['sg'],
iter=self.embedding_config['iter'])
vocabulary_w2v = model.wv.vocab.keys()
count = 0
if self.embedding_config['use_Tencent']:
model_tencent = gensim.models.KeyedVectors.load_word2vec_format(
os.path.join(BASE_DIR, 'static/base/Tencent_AILab_ChineseEmbedding.bin'), binary=True)
vocabulary_tencent = model_tencent.wv.vocab.keys()
vector_matrix = np.zeros((len(feature_words), int(self.embedding_config['size']) + 200))
for word in feature_words:
if word in vocabulary_tencent:
vector_tencent = model_tencent.wv.word_vec(word)
else:
vector_tencent = np.random.randn(200)
if word in vocabulary_w2v:
vector_w2v = model.wv.word_vec(word)
else:
vector_w2v = np.random.randn(self.embedding_config['size'])
vector = np.concatenate((vector_tencent, vector_w2v))
vector_matrix[count] = vector
count += 1
else:
vector_matrix = np.zeros((len(feature_words), self.embedding_config['size']))
for word in feature_words:
if word in vocabulary_w2v:
vector_w2v = model.wv.word_vec(word)
else:
vector_w2v = np.random.randn(self.embedding_config['size'])
vector_matrix[count] = vector_w2v
count += 1
os.makedirs(self.embedding_config['embedding_path'], exist_ok=True)
model.save(os.path.join(self.embedding_config['embedding_path'], 'word2vec.model'))
return vector_matrix
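Taken together, the loader and processor above are meant to be driven from a runner. A hedged usage sketch follows; the config path and the exact package location of BaseDataProcess are assumptions for illustration, not the shipped runner code.
from base.views.data.BaseDataLoader import BaseDataLoader
from base.views.data.BaseDataProcess import BaseDataProcess  # assumed package path

config_path = './config.yaml'              # hypothetical config like the sketch above
loader = BaseDataLoader(config_path)
processor = BaseDataProcess(config_path)

df = loader.read_file()                    # deduplicated rows with non-empty content/label
texts = processor.process(df['content'])   # cleaned, tokenized, space-joined strings
train_set, test_set = processor.split_dataset(texts, use_dev=False)
X, feature_words = processor.bag_of_words(texts, df['label'])  # illustrative; assumes no rows were dropped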
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/5/11 17:04
# @Author : 程婷婷
# @FileName: __init__.py.py
# @Software: PyCharm
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/5/11 16:30
# @Author : 程婷婷
# @FileName: BaseEvaluator.py
# @Software: PyCharm
from sklearn.metrics import precision_score, f1_score, recall_score, classification_report
import logging
from base.views.config.BaseConfig import BaseConfig
formats = '%(asctime)s %(levelname)s %(pathname)s %(funcName)s %(message)s'
logging.basicConfig(format=formats, level=logging.INFO)
class BaseEvaluator:
def __init__(self, config_path):
self.evaluate_config = BaseConfig(config_path)._parsed_file['evaluate']
def evaluate(self, y_true, y_pred, label_mapping, logger):
result = []
y_true = list(map(str, y_true))
y_pred = list(map(str, y_pred))
logger.info('模型评估结果如下:')
if not label_mapping:
result.append(classification_report(y_true, y_pred))
logger.info(classification_report(y_true, y_pred))
else:
for value in label_mapping.values():
print([k for k,v in label_mapping.items() if v == value])
p = precision_score(y_true, y_pred, average=self.evaluate_config['average'], pos_label=str(value))
r = recall_score(y_true, y_pred, average=self.evaluate_config['average'], pos_label=str(value))
f1 = f1_score(y_true, y_pred, average=self.evaluate_config['average'], pos_label=str(value))
print({'value': value,'召回率为': r, '精确率为': p, 'F1': f1})
logger.info('标签为%s' % [k for k,v in label_mapping.items() if v == value][0])
logger.info('精确率为%.2f' %p)
logger.info('召回率为%.2f' %r)
logger.info('F1值为%.2f' % f1)
result.append(str({'label': value,'recall': r, 'precision': p, 'F1': f1}))
return ' '.join(result)
# y_true = [0, 1, 2, 0, 1, 2]
# y_pred = [0, 2, 1, 0, 0, 1]
# print(BaseEvaluator())
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/5/11 17:04
# @Author : 程婷婷
# @FileName: __init__.py.py
# @Software: PyCharm
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/5/11 16:29
# @Author : 程婷婷
# @FileName: BaseLoss.py
# @Software: PyCharm
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/5/11 17:04
# @Author : 程婷婷
# @FileName: __init__.py.py
# @Software: PyCharm
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/5/11 16:18
# @Author : 程婷婷
# @FileName: BaseModel.py
# @Software: PyCharm
from base.views.config.BaseConfig import BaseConfig
import os
import pickle
class BaseModel:
def __init__(self,config_path):
self.model_config = BaseConfig(config_path)._parsed_file['model']
def building_model(self, *params):
pass
def save(self, model):
model_dir = os.path.dirname(self.model_config['model_path'])
os.makedirs(model_dir, exist_ok=True)
with open(self.model_config['model_path'], 'wb') as model_file:
pickle.dump(model, model_file)
def predict(self, model, X):
proba = model.predict_proba(X)
y_predict = model.predict(X)
return {'proba': proba, 'y_predict': y_predict}
\ No newline at end of file
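BaseModel is an abstract base: concrete models are expected to override building_model and reuse save/predict. A hypothetical subclass for illustration (not one of the runners shipped in this commit; the import path is assumed):
from sklearn.linear_model import LogisticRegression
from base.views.model.BaseModel import BaseModel  # assumed package path

class ExampleLogisticModel(BaseModel):
    def building_model(self, X, y):
        # Fit a simple classifier and persist it via the inherited save()
        clf = LogisticRegression(max_iter=1000)
        clf.fit(X, y)
        self.save(clf)  # pickled to model_config['model_path']
        return clf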
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/5/11 17:04
# @Author : 程婷婷
# @FileName: __init__.py.py
# @Software: PyCharm
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/5/11 16:36
# @Author : 程婷婷
# @FileName: BaseRunner.py
# @Software: PyCharm
from base.views.config.BaseConfig import BaseConfig
class BaseRunner:
def __init__(self,config_path):
self.runner_config = BaseConfig(config_path)._parsed_file['runner']
def train(self, logger):
pass
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/5/11 17:04
# @Author : 程婷婷
# @FileName: __init__.py.py
# @Software: PyCharm
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/5/11 9:24
# @Author : 程婷婷
# @FileName: test.py
# @Software: PyCharm
import jieba
import re
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.datasets import load_digits
from sklearn.feature_selection import SelectPercentile, chi2
X, y = load_digits(return_X_y=True)
print(X.shape)
print(X[:10], y[:100])
X_new = SelectPercentile(chi2, percentile=10).fit_transform(X, y)
print(X_new.shape)
print(X_new[:10])
@@ -15,9 +15,9 @@ import ahocorasick
import pandas as pd
from gensim.models.keyedvectors import KeyedVectors
from platform_zzsn.settings import BASE_DIR
- from model.base.views import utils
+ from platform_base.views import base_utils
- General_dict = utils.read_txt(os.path.join(BASE_DIR, 'static/base/dict_sogou.txt'))
+ General_dict = base_utils.read_txt(os.path.join(BASE_DIR, 'static/base/dict_sogou.txt'))
General_dict_ = ''
for key in General_dict:
General_dict_ += ' ' + str(key.strip())
@@ -127,7 +127,7 @@ def summary(text, summary_length):
# zh_nlp = stanza.Pipeline('zh-hans')
# en_nlp = stanza.Pipeline('en')
# nlp_dict = {'zh': zh_nlp, 'en': en_nlp}
- #model = KeyedVectors.load_word2vec_format(os.path.join(BASE_DIR, 'static/base/Tencent_AILab_ChineseEmbedding.bin'), binary=True)
+ #model = KeyedVectors.load_word2vec_format(os.path.join(BASE_DIR, 'static/platform_base/Tencent_AILab_ChineseEmbedding.bin'), binary=True)
# if __name__ == '__main__':
# print(word_cut('汤姆生病了。他去了医院。'))
# print(word_pos('汤姆生病了。他去了医院。'))
......
+ import shutil
from tkinter import _flatten
from django.http import JsonResponse
from django.views.decorators.http import require_POST
from basic_service.views import basic, co_occurrence
- from model.base.views.token_authorize import *
+ from platform_base.views.token_authorize import *
- import shutil
+ from platform_zzsn.settings import BASE_DIR
- UPLOAD_FOLDER = '/home/zzsn/ctt/platform_zzsn/media/'
+ UPLOAD_FOLDER = os.path.join(BASE_DIR, 'media/')
# Create your views here.
@require_POST
......
from django.apps import AppConfig
class BaseConfig(AppConfig):
name = 'base'
from django.db import models
from datetime import datetime
# Create your models here.
class User(models.Model):
username = models.CharField(max_length=30, unique=True)
true_name = models.CharField(max_length=30)
sex = models.CharField(max_length=2)
mobile_number = models.CharField(max_length=20)
mail = models.CharField(max_length=20)
id_card = models.CharField(max_length=20)
password = models.CharField(max_length=40)
account_number = models.CharField(max_length=20)
def toDict(self):
return {'id':self.id,
'username':self.username,
'true_name':self.true_name,
'sex':self.sex,
'mobile_number':self.mobile_number,
'mail':self.mail,
'id_card':self.id_card,
'password':self.password,
'account_number':self.account_number,
# 'update_at':self.update_at.strftime('%Y-%m-%d %H:%M:%S')
}
class Meta:
db_table = 'user'
class ServiceManage(models.Model):
name = models.CharField(max_length=15)
username = models.CharField(max_length=30)
filenames = models.CharField(max_length=200)
create_date = models.DateTimeField(default=None)
end_date = models.DateTimeField(default=None)
state = models.CharField(max_length=10)
path = models.CharField(max_length=20)
def toDict(self):
return {'name': self.name,
'username': self.username,
'filenames': self.filenames,
'create_date': self.create_date.strftime('%Y-%m-%d %H:%M:%S'),
'end_date': self.end_date.strftime('%Y-%m-%d %H:%M:%S'),
'state': self.state,
'path': self.path,
}
class Meta:
db_table = 'service_manage'
class SubjectManage(models.Model):
sid = models.CharField(max_length=10, unique=True)
name = models.CharField(max_length=30)
def toDict(self):
return {'sid': self.sid,
'name': self.name,
}
class Meta:
db_table = 'subject_manage'
class ModelManage(models.Model):
task_name = models.CharField(max_length=30)
function_type = models.CharField(max_length=20)
model_type = models.CharField(max_length=20)
version_num = models.IntegerField()
create_date = models.DateTimeField(default=None)
def toDict(self):
return {'id': self.id,
'task_name': self.task_name,
'function_type': self.function_type,
'model_type': self.model_type,
'version_num': self.version_num,
'create_date': self.create_date.strftime('%Y-%m-%d %H:%M:%S'),
}
class Meta:
db_table = 'model_manage'
class VersionManage(models.Model):
model = models.ForeignKey(ModelManage, related_name='version_model', on_delete=models.CASCADE)
version = models.CharField(max_length=20)
create_date = models.DateTimeField(default=None)
end_date = models.DateTimeField(default=None)
state = models.CharField(max_length=20)
creator = models.CharField(max_length=30)
path = models.CharField(max_length=20, unique=True)
def toDict(self):
return {'id': self.id,
'version': self.version,
'create_date': self.create_date.strftime('%Y-%m-%d %H:%M:%S'),
'end_date': self.end_date.strftime('%Y-%m-%d %H:%M:%S'),
'state': self.state,
'creator': self.creator,
'path': self.path,
}
class Meta:
db_table = 'version_manage'
\ No newline at end of file
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/8/12 18:05
# @Author : 程婷婷
# @FileName: urls.py
# @Software: PyCharm
from model.base.views import views as base_views
from django.conf.urls import url
urlpatterns = [
url(r'^register-account', base_views.register_account, name='register_account'),
url(r'^verify-username', base_views.verify_username, name='verify_username'),
url(r'^login', base_views.login, name='login'),
url(r'^reset-password', base_views.reset_password, name='reset_password'),
url(r'^show-config-file', base_views.show_config_file, name='show_config_file'),
url(r'^show-service-file', base_views.show_service_file, name='show_service_file'),
url(r'^delete-file-row-manage', base_views.delete_file_row_manage, name='delete_file_row_manage'),
url(r'^delete-file-row-service', base_views.delete_file_row_service, name='delete_file_row_service'),
url(r'^file-upload', base_views.file_upload, name='file_upload'),
url(r'^show-log-file', base_views.show_log_file, name='show_log_file'),
url(r'^validate-code', base_views.validate_code, name='validate_code'),
url(r'^download-zip', base_views.download_zip, name='download_zip'),
url(r'^download-xlsx', base_views.download_xlsx, name='download_xlsx'),
url(r'^query-manage', base_views.query_manage, name='query_manage'),
url(r'^forget-password', base_views.forget_password, name='forget_password'),
url(r'^train', base_views.run_train, name='train'),
url(r'^query-service-manage', base_views.query_service_manage, name='query_service_manage'),
url(r'^query-subject', base_views.query_subject, name='query_subject'),
url(r'^query-version', base_views.query_version, name='query_version'),
url(r'^query-task-name', base_views.query_task_name, name='query_task_name')
]
@@ -22,14 +22,13 @@ from platform_zzsn.settings import BASE_DIR
format = '%(asctime)s %(levelname)s %(pathname)s %(funcName)s %(message)s'
logging.basicConfig(format=format, level=logging.INFO)
+ PerceptronLexicalAnalyzer = JClass('com.hankcs.hanlp.model.perceptron.PerceptronLexicalAnalyzer')
+ pla_segment = PerceptronLexicalAnalyzer()
class BaseDataProcess:
def __init__(self, config_path):
self.embedding_config = BaseConfig.BaseConfig(config_path)._parsed_file['embedding']
self.process_config = BaseConfig.BaseConfig(config_path)._parsed_file['data_process']
- PerceptronLexicalAnalyzer = JClass('com.hankcs.hanlp.model.perceptron.PerceptronLexicalAnalyzer')
+ self.pla_segment = pla_segment
- self.pla_segment = PerceptronLexicalAnalyzer()
self.bdl = BaseDataLoader.BaseDataLoader(config_path)
def clean_content(self, content):
@@ -125,7 +124,7 @@ class BaseDataProcess:
count = 0
if self.embedding_config['use_Tencent']:
model_tencent = gensim.models.KeyedVectors.load_word2vec_format(
- os.path.join(BASE_DIR, 'static/base/Tencent_AILab_ChineseEmbedding.bin'), binary=True)
+ os.path.join(BASE_DIR, 'static/platform_base/Tencent_AILab_ChineseEmbedding.bin'), binary=True)
vocabulary_tencent = model_tencent.wv.vocab.keys()
vector_matrix = np.zeros((len(feature_words), int(self.embedding_config['size']) + 200))
for word in feature_words:
......
import os
import yaml
import random
import smtplib
from email.mime.text import MIMEText
from django.core.paginator import Paginator
from email.mime.multipart import MIMEMultipart
from PIL import Image,ImageFont,ImageDraw,ImageFilter
from model.base.models import ModelManage, ServiceManage, VersionManage
from platform_zzsn.settings import BASE_DIR
class Picture:
def __init__(self):
self.size = (240,60)
self.mode='RGB'
self.color='white'
self.font = ImageFont.truetype(os.path.join(BASE_DIR,
'static/common/font/arial.ttf'), 36) # font size
def randChar(self):
basic='23456789abcdefghijklmnpqrstwxyzABCDEFGHIJKLMNPQRSTWXYZ'
return basic[random.randint(0,len(basic)-1)] # pick one random character
def randBdColor(self):
return (random.randint(64,255),random.randint(64,255),random.randint(64,255)) # background noise color
def randTextColor(self):
return (random.randint(32, 127), random.randint(32, 127), random.randint(32, 127)) # random text color
def proPicture(self):
new_image=Image.new(self.mode,self.size,self.color) # new image: mode, size, background color
drawObject=ImageDraw.Draw(new_image) # drawing handle for the image
line_num = random.randint(4,6) # number of interference lines
for i in range(line_num):
#size=(240,60)
begin = (random.randint(0, self.size[0]), random.randint(0, self.size[1]))
end = (random.randint(0, self.size[0]), random.randint(0, self.size[1]))
drawObject.line([begin, end], self.randTextColor())
for x in range(240):
for y in range(60):
tmp = random.randint(0,50)
if tmp>30: # adjust the number of noise points
drawObject.point((x,y),self.randBdColor())
randchar=''
for i in range(5):
rand=self.randChar()
randchar+=rand
drawObject.text([50*i+10,10],rand,self.randTextColor(),font=self.font) # draw the character
new_image = new_image.filter(ImageFilter.SHARPEN) # sharpen filter
return new_image,randchar
def update_config_file(config_path, config_file):
data = yaml.load(config_file, Loader=yaml.FullLoader)
data['data_loader'] = {}
model_path = data['model']['model_path']
model_name = data['model']['model_name']
if data['model']['model_path']:
data['model']['model_path'] = os.path.join(config_path, model_path)
else:
data['model']['model_path'] = os.path.join(config_path, model_name)
print(data['model']['model_path'])
embedding_path = data['embedding']['embedding_path']
if embedding_path:
data['embedding']['embedding_path'] = os.path.join(config_path, data['embedding']['embedding_path'])
else:
if data['embedding']['name']:
data['embedding']['embedding_path'] = os.path.join(config_path, data['embedding']['name'])
tokenizer_path = data['embedding']['tokenizer_path']
if tokenizer_path:
data['embedding']['tokenizer_path'] = os.path.join(config_path, data['embedding']['tokenizer_path'])
try:
test_file_path = data['data_process']['test_file_path']
train_file_path = data['data_process']['train_file_path']
except KeyError:
pass
else:
data['data_process']['test_file_path'] = os.path.join(config_path, test_file_path)
data['data_process']['train_file_path'] = os.path.join(config_path, train_file_path)
for file in os.listdir(config_path):
if ('.xls' == file[-4:]) or ('.xlsx' == file[-5:]):
xlsx_path = os.path.join(config_path, file)
data['data_loader']['dataset_path'] = xlsx_path
if 'save_fname' in data['runner'].keys():
data['runner']['save_fpath'] = os.path.join(config_path, data['runner']['save_fname'])
data['data_loader']['stopwords_path'] = os.path.join(BASE_DIR, 'static/base/baidu_stopwords.txt')
file_path = os.path.join(config_path, 'config.yaml')
with open(file_path, 'w') as yaml_file:
yaml.safe_dump(data, yaml_file, default_flow_style=False)
return file_path
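A minimal invocation sketch for update_config_file (the directory and the uploaded YAML text are hypothetical; yaml.load accepts either a string or a file object):
# cfg_dir: per-task upload directory; cfg_text: raw YAML uploaded from the front end (both assumed)
cfg_dir = '/home/zzsn/platform_zzsn/media/20210820_120000'   # hypothetical path
with open('uploaded_config.yaml', encoding='utf8') as f:     # hypothetical file
    cfg_text = f.read()
new_cfg_path = update_config_file(cfg_dir, cfg_text)  # rewrites relative paths, writes <cfg_dir>/config.yaml
print(new_cfg_path)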
def select_manage(task_name, function_type, model_type, begin_cdate, end_cdate, page_size, current_page):
condition = {'task_name': task_name, 'function_type': function_type, 'model_type': model_type,
'create_date__range': (begin_cdate, end_cdate,)
}
del_keys = []
for key in condition.keys():
if not condition[key]:
del_keys.append(key)
if not condition['create_date__range'][0]:
del_keys.append('create_date__range')
for key in del_keys:
condition.pop(key)
managers = ModelManage.objects.filter(**condition).order_by('-create_date')
len_managers = len(managers)
page = Paginator(managers, page_size)
maxpages = page.num_pages # total number of pages
pIndex = int(current_page)
# Clamp the requested page index to the valid range
if pIndex > maxpages:
pIndex = maxpages
manager_list = page.page(pIndex) # records on the current page
return list(manager_list), len_managers
def select_version(model_id, begin_cdate, end_cdate, page_size, current_page):
condition = {'model_id': model_id,
'create_date__range': (begin_cdate, end_cdate,)
}
del_keys = []
if not condition['create_date__range'][0]:
del_keys.append('create_date__range')
for key in del_keys:
condition.pop(key)
versions = VersionManage.objects.filter(**condition).order_by('-create_date')
len_versions = len(versions)
page = Paginator(versions, page_size)
maxpages = page.num_pages # total number of pages
pIndex = int(current_page)
# Clamp the requested page index to the valid range
if pIndex > maxpages:
pIndex = maxpages
version_list = page.page(pIndex) # records on the current page
return list(version_list), len_versions
def select_service_manage(name, begin_cdate, end_cdate, state, username, page_size, current_page):
condition = {
'name': name,
'state': state,
'create_date__range': (begin_cdate, end_cdate),
'username': username,
}
del_keys = []
for key in condition.keys():
if not condition[key]:
del_keys.append(key)
if not condition['create_date__range'][0]:
del_keys.append('create_date__range')
for key in del_keys:
condition.pop(key)
print(condition)
service_managers = ServiceManage.objects.filter(**condition).order_by('-create_date')
len_service_managers = len(service_managers)
page = Paginator(service_managers, page_size)
maxpages = page.num_pages
pIndex = int(current_page)
# Clamp the requested page index to the valid range
if pIndex > maxpages:
pIndex = maxpages
manager_list = page.page(pIndex) # records on the current page
return list(manager_list), len_service_managers
def sendMail(user,pwd,sender,receiver,msg_title):
mail_host = "smtp.163.com" #163的SMTP服务器
message = MIMEMultipart('alternative')
# Set the sender of the email
message["From"] = sender
# Set the recipients of the email
message["To"] = ",".join(receiver)
# Set the subject of the email
message["Subject"] = msg_title
# Plain-text body (currently disabled)
# message.attach(MIMEText('您好,\n'
# ' 您当前的密码为%s, 为了保证您的账号安全,请尽快登陆重置您的密码'%msg_content, 'plain', 'utf-8'))
# Attach the HTML body
message.attach(MIMEText('<html>'
'<body>'
'<h1>Hello </h1><br> '
'<h3>To ensure the security of your account, please log in and reset your password as soon as possible.</h3>'
'<h2><a href="http://192.168.1.149:8020/reset_password/">点此重置</a></h2>'
'</body>'
'</html>', 'html', 'utf-8'))
# 1. Connect to the SMTP server over SSL
smtpObj = smtplib.SMTP_SSL(mail_host,465)
# 2. Log in to verify the account
smtpObj.login(user,pwd)
# 3. Send the email
# Arguments: sender, recipients, message body
smtpObj.sendmail(sender,receiver,message.as_string())
return True
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/8/20 16:58
# @Author : 程婷婷
# @FileName: token_authorize.py
# @Software: PyCharm
import jwt
import time
import functools
from jwt import exceptions
from django.http import JsonResponse
from platform_zzsn.settings import *
# SECRET_KEY (imported from settings) is the signing key used to create and validate JWTs
def create_token(user):
'''Create a JWT for the given user, valid for three hours.'''
headers = {
"alg": "HS256",
"typ": "JWT"
}
exp = int(time.time() + 3*60*60)
payload = {
"id": user.id,
"name": user.username,
"exp": exp
}
token = jwt.encode(payload=payload, key=SECRET_KEY, algorithm='HS256', headers=headers).decode('utf-8')
return token
def login_required(view_func):
@functools.wraps(view_func)
def validate_token(request, *args, **kwargs):
'''Validate the JWT from the request; if it is valid, execute the wrapped view.'''
payload = None
msg = None
try:
token = request.META.get("HTTP_AUTHORIZATION")
payload = jwt.decode(token, SECRET_KEY, True, algorithms=['HS256'])
print(payload)
return view_func(request, *args, **kwargs)
# JWT validity and signature checks
except exceptions.ExpiredSignatureError:
return JsonResponse({
'handle_msg': 'failure',
'is_handle_success': False,
'logs': '登录已过期'
})
except jwt.DecodeError:
return JsonResponse({
'handle_msg': 'failure',
'is_handle_success': False,
'logs': '缺少参数token'
# token authentication failed
})
except jwt.InvalidTokenError:
return JsonResponse({
'handle_msg': 'failure',
'is_handle_success': False,
'logs': '缺少参数token'
# illegal token
})
return validate_token
\ No newline at end of file
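A short sketch of how these helpers are meant to be wired into Django views (the two views are hypothetical, not the shipped login(); create_token and login_required are the functions defined above, User is the model added in this commit):
from django.http import JsonResponse
from django.views.decorators.http import require_POST
from platform_base.models import User  # assumed import path after the app rename

@require_POST
def example_login(request):
    # Password checking omitted for brevity; issue a token for the authenticated user
    user = User.objects.get(username=request.POST['username'])
    return JsonResponse({'token': create_token(user)})

@require_POST
@login_required  # rejects requests without a valid HTTP_AUTHORIZATION token
def example_protected_view(request):
    return JsonResponse({'handle_msg': 'success'})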
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/8/9 11:19
# @Author : 程婷婷
# @FileName: utils.py
# @Software: PyCharm
import os
import re
import jieba
import zipfile
import pandas as pd
from docx import Document
from platform_zzsn.settings import *
def read_txt(path):
with open(path, 'r', encoding='utf8') as file:
lines = file.readlines()
return lines
def read_docx(pending_file, user_file):
jieba.load_userdict(user_file)
document = Document(pending_file)
doc_text_list = []
for para in document.paragraphs:
para_text = re.sub(r'\s', '', para.text)
if para_text:
doc_text_list.append(para_text)
return doc_text_list
def read_excel(pending_file, user_file):
jieba.load_userdict(user_file)
doc_text_list = pd.read_excel(pending_file)['content']
doc_text_list.dropna(inplace=True)
return doc_text_list
def merge_para(paras):
new_paras = []
for i, para in enumerate(paras):
if not new_paras:
new_paras.append(para)
elif (len(new_paras[-1]) < 500):
new_paras[-1] += para
else:
new_paras.append(para)
return new_paras
def filter_stopwords(para):
path = os.path.join(BASE_DIR, 'static/base/baidu_stopwords.txt')
stopword_list = [k.strip() for k in read_txt(path) if
k.strip() != '']
words = [word for word in jieba.lcut(para) if word not in stopword_list]
return words
# Return the second element of a list (used as a sort key)
def takeSecond(elem):
return elem[1]
def takeFirst_len(elem):
return len(elem[0])
def make_zip(file_dir: str, zip_path: str) -> None:
zip_f = zipfile.ZipFile(zip_path, 'w')
pre_len = len(os.path.dirname(file_dir))
for parent, dir_names, filenames in os.walk(file_dir):
for filename in filenames:
path_file = os.path.join(parent, filename)
arc_name = path_file[pre_len:].strip(os.path.sep)
zip_f.write(path_file, arc_name)
zip_f.close()
@@ -2,4 +2,4 @@ from django.apps import AppConfig
class BaseConfig(AppConfig):
- name = 'base'
+ name = 'platform_base'
@@ -5,9 +5,9 @@
# @FileName: urls.py
# @Software: PyCharm
from django.urls import path
- from base.views import views
+ from platform_base.views import views
from django.conf.urls import url
- from base.views import views as base_views
+ from platform_base.views import views as base_views
urlpatterns = [
url(r'^register-account', base_views.register_account, name='register_account'),
......
@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
# @Time : 2021/8/9 11:19
# @Author : 程婷婷
- # @FileName: utils.py
+ # @FileName: base_utils.py
# @Software: PyCharm
import os
import re
@@ -46,7 +46,7 @@ def merge_para(paras):
return new_paras
def filter_stopwords(para):
- path = os.path.join(BASE_DIR, 'static/base/baidu_stopwords.txt')
+ path = os.path.join(BASE_DIR, 'static/platform_base/baidu_stopwords.txt')
stopword_list = [k.strip() for k in read_txt(path) if
k.strip() != '']
words = [word for word in jieba.lcut(para) if word not in stopword_list]
......
@@ -6,7 +6,7 @@ from email.mime.text import MIMEText
from django.core.paginator import Paginator
from email.mime.multipart import MIMEMultipart
from PIL import Image,ImageFont,ImageDraw,ImageFilter
- from base.models import ModelManage, ServiceManage, VersionManage
+ from platform_base.models import ModelManage, ServiceManage, VersionManage
from platform_zzsn.settings import BASE_DIR
@@ -70,7 +70,7 @@ def update_config_file(config_path, config_file):
data['embedding']['embedding_path'] = os.path.join(config_path, data['embedding']['embedding_path'])
else:
if data['embedding']['name']:
- data['embedding']['embedding_path'] = os.path.join(config_path, data['embedding']['name'])
+ data['embedding']['embedding_path'] = data['embedding']['name']
tokenizer_path = data['embedding']['tokenizer_path']
if tokenizer_path:
data['embedding']['tokenizer_path'] = os.path.join(config_path, data['embedding']['tokenizer_path'])
@@ -88,7 +88,7 @@ def update_config_file(config_path, config_file):
data['data_loader']['dataset_path'] = xlsx_path
if 'save_fname' in data['runner'].keys():
data['runner']['save_fpath'] = os.path.join(config_path, data['runner']['save_fname'])
- data['data_loader']['stopwords_path'] = os.path.join(BASE_DIR, 'static/base/baidu_stopwords.txt')
+ data['data_loader']['stopwords_path'] = os.path.join(BASE_DIR, 'static/platform_base/baidu_stopwords.txt')
file_path = os.path.join(config_path, 'config.yaml')
with open(file_path, 'w') as yaml_file:
......
@@ -11,6 +11,7 @@ import logging
import datetime
import tempfile
import zipfile
+ import traceback
from io import BytesIO
from django.db import transaction
from wsgiref.util import FileWrapper
@@ -20,15 +21,17 @@ from django.forms.models import model_to_dict
from django.http import JsonResponse, HttpResponse
from django.core.files.storage import default_storage
from django.views.decorators.http import require_POST
- from base.views import interaction, utils
+ from platform_base.views import interaction, base_utils
- from base.views.token_authorize import *
+ from platform_base.views.token_authorize import *
- from base.models import User, ModelManage, ServiceManage, SubjectManage, VersionManage
+ from platform_base.models import User, ModelManage, ServiceManage, SubjectManage, VersionManage
- from classify.views.textcnn_classify.TextcnnClassifyRunner import TextcnnClassifyRunner
+ from model.classify.views.textcnn_classify.TextcnnClassifyRunner import TextcnnClassifyRunner
- from classify.views.xgboost_classify.XgboostClassifyRunner import XgboostClassifyRunner
+ from model.classify.views.xgboost_classify.XgboostClassifyRunner import XgboostClassifyRunner
- from classify.views.logistic_classify.LogisticClassifyRunner import LogisticClassifyRunner
+ from model.classify.views.logistic_classify.LogisticClassifyRunner import LogisticClassifyRunner
- from classify.views.fasttext_classify.FastTextRunner import FastTextRunner
+ from model.classify.views.few_multi_class.FewMultiClassRunner import FewMultiRunner
- # from classify.flair_classify.FlairClassifyRunner import FlairClassifyRunner
+ from model.classify.views.few_multi_label.FewMultiLabelRunner import FewMultiLabelRunner
- from clustering.views.KMeans.KmeansRunner import KmeansRunner
+ from model.classify.views.fasttext_classify.FastTextRunner import FastTextRunner
+ # from model.classify.flair_classify.FlairClassifyRunner import FlairClassifyRunner
+ from model.clustering.views.KMeans.KmeansRunner import KmeansRunner
from platform_zzsn.settings import BASE_DIR
print('-----------')
print(BASE_DIR)
@@ -173,7 +176,7 @@ def show_config_file(request):
model_type = request.POST['model_type']
try:
path = os.path.join(BASE_DIR, r'static/common/config_data/'+ model_type + '.yaml')
- data = utils.read_txt(path)
+ data = base_utils.read_txt(path)
except Exception as e:
print(e)
return JsonResponse({
@@ -408,7 +411,7 @@ def show_log_file(request):
path = UPLOAD_FOLDER + path_timestamp
files = [filename for filename in os.listdir(path) if 'log' in filename]
log_path = os.path.join(path, files[0])
- data = utils.read_txt(log_path)
+ data = base_utils.read_txt(log_path)
except Exception as e:
print(e)
return JsonResponse({
@@ -448,6 +451,7 @@ def validate_code(request):
@require_POST
@login_required
@transaction.atomic
+ # @transaction.non_atomic_requests
def run_train(request):
token = request.META.get("HTTP_AUTHORIZATION")
task_name = request.POST['task_name']
@@ -484,8 +488,8 @@ def run_train(request):
create_date=create_time,
)
model_id = max(ModelManage.objects.values_list('id', flat=True))
- else:
+ # else:
- model_manage = ModelManage.objects.get(id=model_id)
+ # model_manage = ModelManage.objects.get(id=model_id)
if not new_version:
versions = VersionManage.objects.filter(model_id=model_id)
new_version = max([int(version.version.replace('V', '')) for version in versions])+1
@@ -504,14 +508,17 @@ def run_train(request):
'logistic': LogisticClassifyRunner(config_path),
# 'flair': FlairClassifyRunner(config_path),
'textcnn': TextcnnClassifyRunner(config_path),
- 'kmeans': KmeansRunner(config_path)}
+ 'kmeans': KmeansRunner(config_path),
+ 'few_multi_class': FewMultiRunner(config_path),
+ 'few_multi_label': FewMultiLabelRunner(config_path)}
train_dict[model_type].train(logger)
end_time = datetime.datetime.strftime(datetime.datetime.now(), '%Y-%m-%d %H:%M:%S')
- version_manage.end_date = end_time
+ # version_manage.end_date = end_time
- version_manage.state = '训练成功'
+ # version_manage.state = '训练成功'
- version_manage.save()
+ # version_manage.save()
- model_manage.version_num = int(version_num) + 1
+ VersionManage.objects.filter(model_id=model_id).update(end_date = end_time, state = '训练成功')
- model_manage.save()
+ ModelManage.objects.filter(id=model_id).update(version_num = int(version_num) + 1)
+ # model_manage.save()
return JsonResponse({
'token': token,
'handleMsg': 'success',
@@ -529,7 +536,7 @@ def run_train(request):
'token': token,
'handleMsg': 'failure',
'isHandleSuccess': False,
- 'logs': str(e),
+ 'logs': str(traceback.format_exc()),
'resultData': False,
})
finally:
......
@@ -37,6 +37,7 @@ INSTALLED_APPS = [
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
+ 'platform_base',
'model.base',
'basic_service',
'model.classify',
......
@@ -16,14 +16,12 @@ Including another URLconf
from django.urls import include, path
import basic_service.urls
import scenario_service.urls
- import model.base.urls
+ import platform_base.urls
urlpatterns = [
# path('admin/', admin.site.urls),
path('basic/', include(basic_service.urls)),
- # path('classify/', include(classify.urls)),
+ path('base/', include(platform_base.urls)),
- # path('clustering/', include(clustering.urls)),
- path('base/', include(model.base.urls)),
path('scenario/', include(scenario_service.urls)),
]
@@ -7,7 +7,7 @@
# coding:utf-8
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer
- from model.base.views.utils import *
+ from platform_base.views.base_utils import *
def cv_tfidf(corpus):
......
@@ -11,7 +11,7 @@ import pandas as pd
from collections import Counter
from requests.adapters import HTTPAdapter
from scenario_service.views import cv_tfidf
- from model.base.views import utils
+ from platform_base.views import base_utils
def post_project_info(url, title, content):
@@ -86,15 +86,15 @@ def post_stock_recruitment_predict(url, file_name):
def cv_tfidf_keywords(download_path, pending_file, user_file):
file_type = pending_file.split('.')[-1]
if (file_type == 'docx') or (file_type == 'doc'):
- doc_text_list = utils.read_docx(pending_file, user_file)
+ doc_text_list = base_utils.read_docx(pending_file, user_file)
- doc_text_list = utils.merge_para(doc_text_list)
+ doc_text_list = base_utils.merge_para(doc_text_list)
else:
# print('运行xlsx文件')
- doc_text_list = utils.read_excel(pending_file, user_file)
+ doc_text_list = base_utils.read_excel(pending_file, user_file)
# print(doc_text_list)
corpus, all_words, = [], []
for para in doc_text_list:
- words = utils.filter_stopwords(para)
+ words = base_utils.filter_stopwords(para)
all_words.extend(words)
corpus.append(' '.join(words))
print("len(corpus):" + str(len(corpus)))
......
@@ -2,8 +2,8 @@ from django.http import JsonResponse
from django.views.decorators.http import require_POST
import pandas as pd
from scenario_service.views import scenario, positive_negative_judgment_base_emotion_words
- from model.base.views.token_authorize import *
+ from platform_base.views.token_authorize import *
- from model.base.models import ServiceManage
+ from platform_base.models import ServiceManage
from platform_zzsn.settings import MEDIA_ROOT
UPLOAD_FOLDER = MEDIA_ROOT
......
@@ -10,7 +10,7 @@ import re
import jieba
import jieba.posseg # 词性获取
import collections # 词频统计库
- from base.views import utils
+ from platform_base.views import base_utils
from platform_zzsn.settings import *
@@ -241,7 +241,7 @@ class WordCount:
seg_list_exact = jieba.posseg.cut(string_data, HMM=True) # 精确模式分词+HMM
object_list = []
# 去除停用词
- stopwords_path = os.path.join(BASE_DIR, 'static/base/baidu_stopwords.txt')
+ stopwords_path = os.path.join(BASE_DIR, 'static/platform_base/baidu_stopwords.txt')
with open(stopwords_path, 'r', encoding='UTF-8') as meaninglessFile:
stopwords = set(meaninglessFile.read().split('\n'))
stopwords.add(' ')
@@ -266,7 +266,7 @@ class WordCount:
if __name__ == '__main__':
pending_file = r'C:\Users\EDZ\Desktop\data1104.xlsx'
user_file = r'C:\Users\EDZ\Desktop\用户自定义词典_样例.txt'
- doc_text_list = utils.read_excel(pending_file, user_file)
+ doc_text_list = base_utils.read_excel(pending_file, user_file)
# print(doc_text_list)
text = '。'.join(doc_text_list)
print("len(corpus):" + str(len(text)))
......
@@ -5,7 +5,7 @@ data_process:
test_size: 0.1
random_state: 2021
embedding:
- name: fxxl
+ path: fxxl
title_weight: 5
title_feature_ratio: 0.1
content_feature_ratio: 0.2
......
@@ -7,7 +7,7 @@ data_process:
random_state: 2021
min_content: 50
embedding:
- name: fxxl
+ path: fxxl
transformer: tf
transformer_norm: l2
embedding_path: null
......