Spaces:
Sleeping
Sleeping
File size: 6,020 Bytes
c069edf 6133a63 f8544e9 7fa4c88 05c34a8 5d20e93 b9d94dc fadc2ea 6133a63 f8544e9 8665f6a e4502ec ad0cd53 e4502ec c9eef99 f8544e9 c069edf b9d94dc c9eef99 8665f6a 7fa4c88 f2af948 5d20e93 f2af948 7fa4c88 e4165c8 4f21ff8 8665f6a ad0cd53 5d20e93 ad0cd53 8665f6a 05c34a8 f8544e9 05c34a8 8665f6a 7fa4c88 f8544e9 f8c3935 e4502ec 1cbbb3e f8c3935 1cbbb3e 7fa4c88 e4502ec 6133a63 b9d94dc 7f7263d 9533a0b 7f7263d 8665f6a 7f7263d 8665f6a e4502ec ad0cd53 7f7263d 8665f6a f8544e9 b9d94dc fadc2ea db2e73b b9d94dc e4502ec fadc2ea b9d94dc 9f559e5 8665f6a e4502ec fc968b1 e4502ec fc968b1 e4502ec 365f24d e4502ec 365f24d 8665f6a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import os
import gc
import io
from llama_cpp import Llama
from concurrent.futures import ThreadPoolExecutor, as_completed
from fastapi import FastAPI, Request, HTTPException, Lifespan
from fastapi.responses import JSONResponse
from tqdm import tqdm
from dotenv import load_dotenv
from pydantic import BaseModel
from huggingface_hub import hf_hub_download, login
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import nltk
import uvicorn
import psutil
import torch
import io
nltk.download('punkt')
# Newer NLTK releases load word_tokenize's Punkt data from 'punkt_tab' rather than
# 'punkt'; fetch both so tokenization works across NLTK versions. (nltk.download
# reports, rather than raises, an unknown package id on older releases.)
nltk.download('punkt_tab')
nltk.download('stopwords')
load_dotenv()
# NOTE(review): fastapi does not export `Lifespan`, so the
# `from fastapi import ..., Lifespan` import at the top of this file raises
# ImportError — drop that name from the import.
app = FastAPI()
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if HUGGINGFACE_TOKEN:
    # Authenticate this process with the Hugging Face Hub for gated/private repos.
    login(token=HUGGINGFACE_TOKEN)
# One dict per model; ModelManager.load_models reads the keys 'name', 'repo_id'
# and 'filename' from each entry.
model_configs = [
    # Add model configuration dicts here.
]
# Shared module state; ModelManager reads 'model_configs' from here.
global_data = {'model_configs': model_configs, 'training_data': io.StringIO()}
class ModelManager:
    """Download and hold one llama.cpp model per entry in ``global_data['model_configs']``.

    All configured models are loaded eagerly from ``__init__``. A model that fails
    to download or load is printed and stored as ``None``; ``get_model`` returns
    ``None`` both for such failures and for names that were never configured.
    """

    def __init__(self):
        # Maps config 'name' -> loaded Llama instance, or None if loading failed.
        self.models = {}
        self.load_models()

    def load_models(self):
        """Load every configured model whose name is not already in ``self.models``.

        Each config dict must provide ``'name'``, ``'repo_id'`` and ``'filename'``.
        Exceptions are caught per model (and the model recorded as ``None``) so one
        bad entry does not stop the others from loading.
        """
        for config in tqdm(global_data['model_configs'], desc="Loading models"):
            model_name = config['name']
            if model_name in self.models:
                continue
            try:
                # hf_hub_download returns the local filesystem path of the cached
                # file (a str), not its bytes; Llama takes that path directly.
                model_path = hf_hub_download(
                    repo_id=config['repo_id'],
                    filename=config['filename'],
                    token=HUGGINGFACE_TOKEN,
                )
                # NOTE(review): `n_gpu` is not a documented Llama argument (GPU
                # offload is `n_gpu_layers`); depending on the llama-cpp-python
                # version it is ignored or rejected — confirm the intent.
                model = Llama(model_path=model_path, n_ctx=512, n_gpu=1)
                self.models[model_name] = model
                print(f"Model '{model_name}' loaded successfully.")
            except Exception as e:
                print(f"Error loading model {model_name}: {e}")
                self.models[model_name] = None
            finally:
                # Reclaim temporaries from the download/load before the next model.
                gc.collect()

    def get_model(self, model_name: str):
        """Return the loaded model for *model_name*, or ``None`` if unavailable."""
        return self.models.get(model_name)
# Created at import time: ModelManager.__init__ immediately downloads and loads
# every configured model, so importing this module does that work.
model_manager = ModelManager()
class ChatRequest(BaseModel):
    """Request schema with a single required string field, ``message``.

    NOTE(review): the /generate_multimodel endpoint in this file parses the raw
    JSON body itself and does not use this model — wire it in or remove it.
    """
    message: str
async def generate_model_response(model, inputs: str) -> str:
    """Run *model* on *inputs* (capped at 150 new tokens) and return the stripped text.

    Never raises: a falsy *model* yields ``"Model not loaded"``, and any exception
    during generation is turned into an error string that embeds its message.
    """
    try:
        if not model:
            return "Model not loaded"
        completion = model(inputs, max_tokens=150)
        first_choice = completion['choices'][0]
        return first_choice['text'].strip()
    except Exception as err:
        return f"Error: Could not generate a response. Details: {err}"
async def process_message(message: str) -> dict:
    """Send *message* to every loaded model and pick the reply most similar to it.

    Models for which ``model_manager.get_model`` returns a truthy value are queried
    in parallel on a thread pool of at most 4 workers. Replies are ranked by TF-IDF
    cosine similarity against *message* (NLTK word tokens, English stop words
    removed); ties go to the model listed first in ``global_data['model_configs']``.

    Args:
        message: Raw user prompt; surrounding whitespace is stripped before it is
            sent to the models.

    Returns:
        ``{"best_response": {"model": name, "text": reply}, "all_responses": {name: reply}}``.
        When no model is loaded, ``best_response`` is
        ``{"model": None, "text": "No models loaded successfully."}``.
    """
    import asyncio

    inputs = message.strip()
    loaded_models = []
    for config in global_data['model_configs']:
        model = model_manager.get_model(config['name'])
        if model:
            loaded_models.append((config['name'], model))

    def _generate(model):
        # generate_model_response is a coroutine function: submitting it to the pool
        # directly would only create coroutine objects that never run. Drive each
        # one to completion on a private event loop inside the worker thread.
        return asyncio.run(generate_model_response(model, inputs))

    results = {}
    if loaded_models:  # ThreadPoolExecutor rejects max_workers=0
        # NOTE(review): waiting on the pool here blocks the server's event loop
        # until every model has answered.
        with ThreadPoolExecutor(max_workers=min(len(loaded_models), 4)) as executor:
            # as_completed yields futures in completion order, so map each future
            # back to its model name instead of relying on position.
            future_to_name = {executor.submit(_generate, model): name for name, model in loaded_models}
            for future in tqdm(as_completed(future_to_name), total=len(future_to_name), desc="Generating responses"):
                model_name = future_to_name[future]
                try:
                    results[model_name] = future.result()
                except Exception as e:
                    results[model_name] = f"Error processing {model_name}: {e}"

    # Report replies in configuration order so output and tie-breaking are deterministic.
    responses = {name: results[name] for name, _ in loaded_models}
    if not responses:
        return {"best_response": {"model": None, "text": "No models loaded successfully."}, "all_responses": responses}

    model_names = list(responses)
    response_texts = list(responses.values())
    vectorizer = TfidfVectorizer(
        tokenizer=word_tokenize,
        token_pattern=None,  # ignored when a tokenizer is given; None avoids sklearn's warning
        stop_words=list(stopwords.words('english')),  # sklearn accepts a list or 'english', not a set
    )
    try:
        tfidf_matrix = vectorizer.fit_transform([message] + response_texts)
    except ValueError:
        # Raised for an empty vocabulary, e.g. when every text consists only of stop
        # words or punctuation; fall back to the first model's reply.
        best_index = 0
    else:
        # Row 0 is the prompt; compare it against every reply row.
        similarities = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:])
        best_index = int(similarities.argmax())
    return {
        "best_response": {"model": model_names[best_index], "text": response_texts[best_index]},
        "all_responses": responses,
    }
@app.post("/generate_multimodel")
async def api_generate_multimodel(request: Request):
    """Handle ``POST /generate_multimodel`` with a JSON body such as ``{"message": "..."}``.

    Returns:
        A JSONResponse carrying the dict built by ``process_message``.

    Raises:
        HTTPException: 400 when the body is not valid JSON, is not a JSON object,
            or its ``message`` is missing, not a string, or blank.

    Any other failure is reported as HTTP 500 with body ``{"error": "<details>"}``.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

    # Validate the shape up front: a non-object body or a non-string message would
    # otherwise surface later as an AttributeError and a misleading 500.
    message = data.get("message") if isinstance(data, dict) else None
    if not isinstance(message, str) or not message.strip():
        raise HTTPException(status_code=400, detail="Missing message")

    try:
        response = await process_message(message)
        return JSONResponse(response)
    except HTTPException:
        raise
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
async def startup():
    """Application startup hook; intentionally does nothing."""
    return None
async def shutdown():
    """Application shutdown hook: force a full garbage-collection pass."""
    gc.collect()
    return None
# Attach the startup/shutdown coroutines to the FastAPI application.
app.add_event_handler(event_type="startup", func=startup)
app.add_event_handler(event_type="shutdown", func=shutdown)
def release_resources():
    """Best-effort release of unreachable Python objects and PyTorch's cached CUDA memory.

    Errors are printed, never raised, so a monitoring loop calling this keeps running.
    Note that ``torch.cuda.empty_cache()`` only affects memory held by PyTorch's
    caching allocator.
    """
    try:
        # Collect garbage first: empty_cache() can only hand back cached blocks whose
        # tensors are already freed, so breaking reference cycles beforehand lets it
        # release more.
        gc.collect()
        torch.cuda.empty_cache()
    except Exception as e:
        print(f"Failed to release resources: {e}")
def resource_manager(*, interval=5.0, stop_event=None, max_ram_percent=20,
                     max_cpu_percent=20, max_gpu_percent=20, max_ram_mb=2048):
    """Monitor system resources in a loop and react when the configured limits are exceeded.

    Each pass is followed by a pause of *interval* seconds. A pass does three checks:

    * System RAM above *max_ram_percent* percent, or more than *max_ram_mb* MiB used:
      call ``release_resources()``.
    * System-wide CPU above *max_cpu_percent* percent: on POSIX, raise this process's
      niceness to 10 (lower scheduling priority) if it is currently below that.
    * PyTorch-reserved memory on the current CUDA device above *max_gpu_percent*
      percent of the device total: call ``release_resources()``.

    Errors inside a pass are printed and the loop continues. It runs until
    *stop_event* (a ``threading.Event``) is set, or forever when that is ``None``.
    The keyword-only limits default to the previously hard-coded values.
    """
    import time

    low_priority_nice = 10

    while stop_event is None or not stop_event.is_set():
        try:
            virtual_mem = psutil.virtual_memory()
            current_ram_mb = virtual_mem.used / (1024 * 1024)
            if virtual_mem.percent > max_ram_percent or current_ram_mb > max_ram_mb:
                release_resources()

            # Without an interval argument this reports usage since the previous
            # call, i.e. over the last pause.
            if psutil.cpu_percent() > max_cpu_percent and psutil.POSIX:
                process = psutil.Process(os.getpid())
                # nice() with no argument only reads the value; pass one to change it.
                # Unprivileged processes cannot lower their niceness again, so this
                # change is one-way.
                if process.nice() < low_priority_nice:
                    process.nice(low_priority_nice)

            if torch.cuda.is_available():
                gpu = torch.cuda.current_device()
                # torch.cuda has no memory_percent(); derive a percentage from the
                # caching allocator's reserved bytes, the pool empty_cache() shrinks.
                total_bytes = torch.cuda.get_device_properties(gpu).total_memory
                gpu_mem_percent = 100 * torch.cuda.memory_reserved(gpu) / total_bytes
                if gpu_mem_percent > max_gpu_percent:
                    release_resources()
        except Exception as e:
            print(f"Error in resource manager: {e}")

        # Pause between passes so the loop does not spin a full core and
        # cpu_percent() has a meaningful sampling window.
        if stop_event is None:
            time.sleep(interval)
        else:
            stop_event.wait(interval)
if __name__ == "__main__":
    import threading

    # Daemon thread so the monitor never keeps the process alive after uvicorn exits.
    resource_thread = threading.Thread(target=resource_manager, daemon=True)
    resource_thread.start()
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)