diff --git a/flask_batch.py b/flask_batch.py index fbd8be2..16f9e20 100644 --- a/flask_batch.py +++ b/flask_batch.py @@ -1,3 +1,5 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0" from flask import Flask, jsonify from flask import request from transformers import pipeline diff --git a/predict.py b/predict.py index 4ac8318..799db26 100644 --- a/predict.py +++ b/predict.py @@ -1,5 +1,4 @@ import time - from vllm import LLM, SamplingParams prompts = [