# -*- coding: utf-8 -*-
"""Flask service exposing Qwen-VL-Chat for product-image marketing analysis.

POST /vl_chat with an image file plus form fields (commodity, type,
additional) returns model-generated selling points / marketing copy /
product names, depending on `type`.
"""

import os

# Must be set before torch / transformers initialise CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "3,1"

import re
import time

from flask import Flask, jsonify, request
from werkzeug.utils import secure_filename
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

app = Flask(__name__)

# Directory where uploaded images are stored.
UPLOAD_FOLDER = 'uploads'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# fix: both directories were assumed to exist; file.save / log writes crashed
# on a fresh checkout.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs('log_file', exist_ok=True)

# Matches a numbered bullet such as "1. text", capturing the text after the
# marker.  fix: the dot is now escaped — the original "[1-9].(.*)" matched
# ANY character after the digit, accepting lines like "1x text".
RE_CHINA_NUMS = r"[1-9]\.(.*)"

# Accepted upload extensions.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}

model_path = "/home/majiahui/project/models-llm/Qwen-VL-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, device_map="cuda", trust_remote_code=True, bf16=True).eval()
model.generation_config = GenerationConfig.from_pretrained(
    model_path, trust_remote_code=True)

# Prompt templates keyed by the request's "type" form field.
# These are user-facing model prompts and are kept verbatim (Chinese).
prompt_picture = {
    "1": "图中的商品:{},有什么突出亮点和卖点,请分条列举出来,要求亮点或者卖点要用一个词总结,冒号后面在进行解释,例如:1. 时尚黑色:图中的鞋子是黑色的,符合时尚潮流,适合不同场合的穿搭。",
    "2": "图中的商品:{},有什么亮点,写一段营销话语",
    "3": "图中的商品:{},有以下亮点:\n{}\n根据这些优势亮点,写一段营销文本让商品卖的更好",
    "4": "图中的商品:{},有哪些不足之处可以改进?",
    "5": "图中{}的渲染图做哪些调整可以更吸引消费者",
    "6": "根据图中的商品:{},生成5个商品名称,要求商品名称格式中包含的信息(品牌名,产品名,细分产品种类词,三到五个卖点和形容词),请分条列举出来,例如:1. xxx \n2. xxx \n3. xxx \n4. xxx \n5. xxx",
}


class log:
    """Minimal file logger: one dated file per day under log_file/."""

    def __init__(self):
        pass

    def log(*args, **kwargs):
        # NOTE: deliberately has no `self` — the file calls this as
        # `log.log(...)` directly on the class, so all positional arguments
        # are message parts.
        now = time.localtime(int(time.time()))
        stamp = time.strftime('%Y/%m/%d-%H:%M:%S', now)
        day = time.strftime('%Y-%m-%d', now)
        log_file = 'log_file/access-%s' % day + ".log"
        # 'a+' creates the file when absent, so a single branch covers both
        # the first write of the day and every later append.
        with open(log_file, 'a+', encoding='utf-8') as f:
            print(stamp, *args, file=f, **kwargs)


def allowed_file(filename):
    """Return True when *filename* has an extension in ALLOWED_EXTENSIONS."""
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def picyure_model_predict(image_path, prompt):
    """Run one image+text query through the model and return the text reply.

    NOTE(review): model.chat is presumably not thread-safe, yet the app runs
    with threaded=True — concurrent requests may interleave; consider a lock.
    """
    query = tokenizer.from_list_format([
        {'image': image_path},
        {'text': prompt},
    ])
    response, _history = model.chat(tokenizer, query=query, history=None)
    return response


def _selling_points(path, commodity, max_attempts=20):
    """Type "1": query until the model returns a well-formed numbered list.

    Each line must look like "N. keyword:explanation".  Returns a list of
    [line, keyword-split] pairs.  Raises ValueError after *max_attempts*
    malformed replies.
    """
    prompt = prompt_picture["1"].format(commodity)
    for _ in range(max_attempts):
        result = picyure_model_predict(path, prompt)
        result_list = str(result).split("\n")
        # fix: validity is now re-evaluated from scratch on every attempt.
        # The original initialised its flags once before `while True`, so a
        # single malformed reply left a flag False forever and the loop
        # never terminated, hanging the request thread.
        if len(result_list) <= 3:
            continue
        well_formed = all(
            re.findall(RE_CHINA_NUMS, line) and ":" in line
            for line in result_list
        )
        if well_formed:
            break
    else:
        # Bounded retry instead of the original unbounded loop; the route's
        # exception handler turns this into a 400 response.
        raise ValueError("model never produced a well-formed numbered list")

    maidian_list = []
    for line in result_list:
        match = re.findall(RE_CHINA_NUMS, line)
        # NOTE(review): split(":") keeps the whole [keyword, explanation]
        # list as in the original; presumably only the keyword ([0]) was
        # wanted — verify against consumers before changing.
        maidian_list.append([line, match[0].split(":")])
    return maidian_list


def picture_main(path, commodity, type, additional):
    """Dispatch one analysis request by its `type` form field.

    path       -- saved upload path of the image
    commodity  -- product name substituted into the prompt template
    type       -- "1".."7" selecting the prompt / post-processing
    additional -- extra text: selling points for "3", raw prompt for "7"
    """
    if type == "1":
        return _selling_points(path, commodity)
    if type == "3":
        # The only template taking two slots: commodity plus its highlights.
        return picyure_model_predict(
            path, prompt_picture["3"].format(commodity, additional))
    if type in ("2", "4", "5", "6"):
        # These four shared an identical copy-pasted body in the original.
        return picyure_model_predict(
            path, prompt_picture[type].format(commodity))
    if type == "7":
        # Free-form: the caller supplies the whole prompt.
        return picyure_model_predict(path, additional)
    # Original fallback sentinel for unknown types — kept verbatim.
    return "1111"


@app.route('/vl_chat', methods=['POST'])
def upload_file():
    """Handle an image upload + analysis request; responds with JSON."""
    if 'file' not in request.files:
        return "1"
    file = request.files.get('file')
    commodity = request.form.get('commodity')
    type = request.form.get('type')
    additional = request.form.get("additional")
    if not (file and allowed_file(file.filename)):
        return "不允许的文件类型"
    filename = secure_filename(file.filename)
    path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(path)
    try:
        result = picture_main(path, commodity, type, additional)
        return_text = {"texts": result, "probabilities": None,
                       "status_code": 200}
    # fix: was a bare `except:`, which also swallowed SystemExit and
    # KeyboardInterrupt; narrowed to Exception (same user-visible reply).
    except Exception:
        return_text = {"texts": "输入格式应该为字典", "probabilities": None,
                       "status_code": 400}
    log.log('start at',
            'filename:{}, commodity:{}, type:{}, additional:{}, result:{}'.format(
                path, commodity, type, additional, return_text))
    return jsonify(return_text)


if __name__ == "__main__":
    # threaded=True: see the thread-safety note on picyure_model_predict.
    app.run(host="0.0.0.0", port=19000, threaded=True)