import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3,1"

import re
import time

import torch
from flask import Flask, jsonify, request
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from werkzeug.utils import secure_filename

app = Flask(__name__)

# Directory where uploaded images are stored
UPLOAD_FOLDER = 'uploads'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
os.makedirs(UPLOAD_FOLDER, exist_ok=True)  # create the upload directory if it does not exist yet

# Regex for numbered list lines such as "1. ..." (dot escaped so only a literal period after the digit matches)
RE_CHINA_NUMS = r"[1-9]\.(.*)"

# Allowed upload file types
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}

# Load Qwen-VL-Chat once at startup
model_path = "/home/majiahui/project/models-llm/Qwen-VL-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cuda", trust_remote_code=True, bf16=True).eval()
model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)
# Prompt templates keyed by request "type" (sent verbatim to Qwen-VL-Chat, so the
# Chinese wording is kept): 1 = itemized selling points, 2 = marketing copy,
# 3 = marketing copy from a given list of selling points, 4 = possible improvements,
# 5 = render adjustments, 6 = five product-title candidates.
prompt_picture = {
    "1": "图中的商品:{},有什么突出亮点和卖点,请分条列举出来,要求亮点或者卖点要用一个词总结,冒号后面在进行解释,例如:1. 时尚黑色:图中的鞋子是黑色的,符合时尚潮流,适合不同场合的穿搭。",
    "2": "图中的商品:{},有什么亮点,写一段营销话语",
    "3": "图中的商品:{},有以下亮点:\n{}\n根据这些优势亮点,写一段营销文本让商品卖的更好",
    "4": "图中的商品:{},有哪些不足之处可以改进?",
    "5": "图中{}的渲染图做哪些调整可以更吸引消费者",
    "6": "根据图中的商品:{},生成5个商品名称,要求商品名称格式中包含的信息(品牌名,产品名,细分产品种类词,三到五个卖点和形容词),请分条列举出来,例如:1. xxx \n2. xxx \n3. xxx \n4. xxx \n5. xxx",
}
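# For example (values are illustrative only), prompt_picture["3"].format("运动鞋", "1. 轻便\n2. 透气")
# fills the product name and the supplied selling-point list into the two "{}" slots of
# template 3; templates 1, 2, 4, 5 and 6 take a single product-name argument.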
# prompt_text = {
# "1": "图中{}有什么突出亮点,请列举出来",
# "2": "图中{}有什么亮点,写一段营销话语",
# "3": "图中{}有以下亮点:\n{}\n根据这些优势亮点,写一段营销文本让商品买的更好",
# "4": "图中{}有哪些不足之处可以改进?",
# "5": "图中{}的渲染图做哪些调整可以更吸引消费者",
# }
class log:
    """Minimal file logger: each call appends a timestamped line to log_file/access-YYYY-MM-DD.log."""

    def __init__(self):
        pass

    # Called directly on the class (log.log(...)), so there is no self parameter.
    def log(*args, **kwargs):
        format = '%Y/%m/%d-%H:%M:%S'
        format_h = '%Y-%m-%d'
        value = time.localtime(int(time.time()))
        dt = time.strftime(format, value)
        dt_log_file = time.strftime(format_h, value)
        log_file = 'log_file/access-%s.log' % dt_log_file
        # Make sure the log directory exists, then append (mode 'a' also creates the file)
        os.makedirs('log_file', exist_ok=True)
        with open(log_file, 'a', encoding='utf-8') as f:
            print(dt, *args, file=f, **kwargs)
# Check that the uploaded file has an allowed extension
def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

# Run one image + text query through Qwen-VL-Chat and return the model's reply
def picture_model_predict(image_path, prompt):
    query = tokenizer.from_list_format([
        {'image': image_path},
        {'text': prompt},
    ])
    response, history = model.chat(tokenizer, query=query, history=None)
    return response
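# Example (path and product name are illustrative):
#   picture_model_predict("uploads/shoes.jpg", prompt_picture["2"].format("运动鞋"))
# sends the image plus the filled prompt as a single-turn chat and returns the model's text reply.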
def picture_main(path, commodity, type, additional):
    # type "1": itemized selling points; re-query until the reply is a well-formed numbered list
    if type == "1":
        prompt_text = prompt_picture[type].format(commodity)
        while True:
            # Format checks are evaluated from scratch on every attempt
            result_list_len = False
            dan_result_geshi = True
            dan_result_geshi_maohao = True
            result = picture_model_predict(path, prompt_text)
            result_list = str(result).split("\n")
            if len(result_list) > 3:
                result_list_len = True
            for i in result_list:
                response_re = re.findall(RE_CHINA_NUMS, i)
                if response_re == []:
                    # Line is not of the form "1. ..."
                    dan_result_geshi = False
                    break
                if ":" not in i:
                    # Line is missing the keyword/explanation separator
                    dan_result_geshi_maohao = False
                    break
            if result_list_len and dan_result_geshi and dan_result_geshi_maohao:
                break
        # Split each "1. keyword:explanation" line into [line, [keyword, explanation]]
        maidian_list = []
        for i in result_list:
            response_re = re.findall(RE_CHINA_NUMS, i)
            guanjianci = response_re[0].split(":")
            maidian_list.append([i, guanjianci])
        return maidian_list
    # types "2", "4", "5", "6": single-shot prompts that only differ in the template used
    elif type in ("2", "4", "5", "6"):
        prompt_text = prompt_picture[type].format(commodity)
        return picture_model_predict(path, prompt_text)
    # type "3": template takes both the product name and the supplied selling points
    elif type == "3":
        prompt_text = prompt_picture[type].format(commodity, additional)
        return picture_model_predict(path, prompt_text)
    # type "7": free-form prompt supplied by the caller in "additional"
    elif type == "7":
        return picture_model_predict(path, additional)
    else:
        return "1111"
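# Illustrative sketch of the type "1" parsing above (the reply text is invented):
# the model is re-queried until it returns more than three lines that each look like
#   "1. 时尚黑色:鞋面为黑色,百搭耐脏"
# Each line is matched with RE_CHINA_NUMS, the captured part is split on ":", and
# maidian_list collects [original_line, [keyword, explanation]] pairs, e.g.
#   ["1. 时尚黑色:鞋面为黑色,百搭耐脏", [" 时尚黑色", "鞋面为黑色,百搭耐脏"]]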
# File-upload endpoint: accepts an image plus form fields and returns the generated text
@app.route('/vl_chat', methods=['POST'])
def upload_file():
    if 'file' not in request.files:
        return "1"
    file = request.files.get('file')
    commodity = request.form.get('commodity')
    type = request.form.get('type')
    additional = request.form.get("additional")
    if file and allowed_file(file.filename):
        filename = secure_filename(file.filename)
        path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(path)
        # Run the requested generation task
        try:
            result = picture_main(path, commodity, type, additional)
            return_text = {"texts": result, "probabilities": None, "status_code": 200}
        except Exception:
            return_text = {"texts": "输入格式应该为字典", "probabilities": None, "status_code": 400}
        log.log('start at',
                'filename:{}, commodity:{}, type:{}, additional:{}, result:{}'.format(
                    path, commodity, type, additional, return_text))
        return jsonify(return_text)
    else:
        return "不允许的文件类型"
# Text-only endpoint (no file upload), currently disabled
# @app.route('/chat', methods=['POST'])
# def upload_file():
#
#     type = request.files.get('type')
#     describe = request.form.get("describe")
#     advantage = request.form.get("dadvantage")
#
#     return "1"
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=19000, threaded=True)
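# ---------------------------------------------------------------------------
# Example client call (a sketch, not part of the service). It assumes the server
# is reachable at http://127.0.0.1:19000 and that "shoes.jpg" exists locally;
# both are illustrative placeholders.
#
#   import requests
#
#   with open("shoes.jpg", "rb") as f:
#       resp = requests.post(
#           "http://127.0.0.1:19000/vl_chat",
#           files={"file": f},
#           data={"commodity": "运动鞋", "type": "2", "additional": ""},
#       )
#   print(resp.json())  # {"texts": "...", "probabilities": None, "status_code": 200}
# ---------------------------------------------------------------------------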