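"""FastAPI service that sends each incoming prompt to several small chat models
served through LangChain's VLLM wrapper, caches prompts with GPTCache, and
returns the response whose sentence embedding is, on average, most similar to
the other models' responses."""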
import os
import hashlib
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from langchain_community.llms import VLLM
from gptcache import Cache
from gptcache.manager.factory import manager_factory
from gptcache.processor.pre import get_prompt
from langchain_community.callbacks.manager import get_openai_callback
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
import torch
import langchain
from langchain.cache import GPTCache
app = FastAPI()
def get_hashed_name(name):
    # Hash the model identifier so it can be used as a filesystem-safe cache directory suffix.
    return hashlib.sha256(name.encode()).hexdigest()
def init_gptcache(cache_obj: Cache, llm: str):
    # Give every LLM its own exact-match ("map") GPTCache store on disk.
    hashed_llm = get_hashed_name(llm)
    cache_obj.init(
        pre_embedding_func=get_prompt,
        data_manager=manager_factory(manager="map", data_dir=f"map_cache_{hashed_llm}"),
    )
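# Note: the "map" manager only returns hits for prompts that exactly match a
# previously cached prompt; data_dir above keeps one cache folder per model,
# e.g. map_cache_<sha256-of-llm-string>/.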
hf_token = os.environ.get("HF_TOKEN")

llm_models = {
    "TinyLlama": VLLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", trust_remote_code=True, max_new_tokens=50, temperature=0.1, use_auth_token=hf_token, device="cpu"),
    "yi-coder": VLLM(model="01-ai/Yi-Coder-1.5B", trust_remote_code=True, max_new_tokens=50, temperature=0.6, use_auth_token=hf_token, device="cpu"),
    "llama": VLLM(model="meta-llama/Llama-3.2-3B-Instruct", trust_remote_code=True, max_new_tokens=50, temperature=0.1, use_auth_token=hf_token, device="cpu"),
    "qwen": VLLM(model="Qwen/Qwen2.5-1.5B-Instruct", trust_remote_code=True, max_new_tokens=50, temperature=0.6, use_auth_token=hf_token, device="cpu"),
}

# Register GPTCache as LangChain's global LLM cache. The wrapper calls
# init_gptcache(cache_obj, llm_string) the first time each model is used,
# so no Cache object has to be created or initialized by hand here.
langchain.llm_cache = GPTCache(init_gptcache)
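# Illustration only (not executed): because langchain.llm_cache is set globally,
# a repeated, identical prompt to the same model is answered from the cache:
#   llm_models["qwen"].invoke("Hello")  # first call runs the model
#   llm_models["qwen"].invoke("Hello")  # second call is served by GPTCache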
# Embedding model used to score agreement between the candidate responses.
try:
    sentence_model = SentenceTransformer('all-mpnet-base-v2', device='cpu')
except Exception as e:
    print(f"Error loading SentenceTransformer: {e}")
    sentence_model = None
@app.get("/")
def read_root():
    return {"Hello": "World"}
@app.post("/v1/generateText")
async def generateText(request: Request):
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
    max_tokens = request_dict.get("max_tokens", -1)
    all_responses = {}
    for model_name, llm in llm_models.items():
        try:
            with get_openai_callback() as cb:
                if max_tokens == -1:
                    # No limit requested: a single generation call.
                    full_response = llm(prompt)
                else:
                    # Ask the model to keep continuing its own output until the
                    # accumulated text reaches max_tokens characters (a rough
                    # proxy for tokens) or the model returns an empty continuation.
                    full_response = ""
                    current_prompt = prompt
                    while True:
                        response_part = llm(current_prompt, max_new_tokens=max_tokens)
                        full_response += response_part
                        if len(full_response) >= max_tokens or response_part == "":
                            break
                        current_prompt = full_response
                print(cb)
            all_responses[model_name] = full_response
            print(f"Model {model_name}: {full_response}")
        except Exception as e:
            print(f"Error with model {model_name}: {e}")
    if not all_responses:
        return JSONResponse({"error": "All models failed to generate text"}, status_code=500)
    if sentence_model:
        # Pick the response that agrees most with the others: embed every response,
        # compute pairwise cosine similarities, and keep the response with the
        # highest average similarity.
        embeddings = sentence_model.encode(list(all_responses.values()))
        similarities = cosine_similarity(embeddings)
        avg_similarity = similarities.mean(axis=0)
        best_model = list(all_responses.keys())[avg_similarity.argmax()]
        best_response = all_responses[best_model]
    else:
        # Fall back to the first successful model if the embedding model is unavailable.
        best_model = list(all_responses.keys())[0]
        best_response = all_responses[best_model]
    return JSONResponse({"best_model": best_model, "text": best_response, "all_responses": all_responses})
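# Example request (assumes the server below is reachable on localhost:7860):
#   curl -X POST http://localhost:7860/v1/generateText \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about caching", "max_tokens": 80}'
# The JSON response contains "best_model", "text", and "all_responses".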
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)