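"""FastAPI app that fans a prompt out to several small chat/code models served
through LangChain's VLLM wrapper, caches generations with GPTCache, and returns
the response that is most representative of all model outputs (highest mean
cosine similarity between sentence-transformer embeddings), alongside every
model's raw answer."""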
import os
import hashlib
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from langchain.llms import VLLM
from gptcache import Cache
from gptcache.manager.factory import manager_factory
from gptcache.processor.pre import get_prompt
from langchain_community.callbacks.manager import get_openai_callback
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
import torch
import langchain
app = FastAPI()
def get_hashed_name(name):
    # Hash the model name so it can safely be used in a cache directory name.
    return hashlib.sha256(name.encode()).hexdigest()
def init_gptcache(cache_obj, llm):
    # Give each model its own on-disk "map" cache, keyed by a hash of the model string.
    hashed_llm = get_hashed_name(llm)
    cache_obj.init(
        pre_embedding_func=get_prompt,
        data_manager=manager_factory(manager="map", data_dir=f"map_cache_{hashed_llm}"),
    )
hf_token = os.environ.get("HF_TOKEN")
llm_models = {
"TinyLlama": VLLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", trust_remote_code=True, max_new_tokens=50, temperature=0.1, use_auth_token=hf_token, device="cpu"),
"yi-coder": VLLM(model="01-ai/Yi-Coder-1.5B", trust_remote_code=True, max_new_tokens=50, temperature=0.6, use_auth_token=hf_token, device="cpu"),
"llama": VLLM(model="meta-llama/Llama-3.2-3B-Instruct", trust_remote_code=True, max_new_tokens=50, temperature=0.1, use_auth_token=hf_token, device="cpu"),
"qwen": VLLM(model="Qwen/Qwen2.5-1.5B-Instruct", trust_remote_code=True, max_new_tokens=50, temperature=0.6, use_auth_token=hf_token, device="cpu"),
}
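# Note: all four models are loaded eagerly at startup and configured for CPU,
# which can be slow and memory-hungry; trimming this dict is the simplest way
# to lighten the app.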
# Register GPTCache as LangChain's LLM cache. LangChain invokes init_gptcache
# with a fresh Cache object and the model string, so each model ends up with
# its own on-disk map cache (GPTCache takes the init callable, not a session).
langchain.llm_cache = langchain.cache.GPTCache(init_gptcache)
try:
sentence_model = SentenceTransformer('all-mpnet-base-v2', device='cpu')
except Exception as e:
print(f"Error loading SentenceTransformer: {e}")
sentence_model = None
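# The sentence encoder is only used to rank the candidate responses below; the
# app still works without it, falling back to the first model that answered.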
@app.get("/")
def read_root():
return {"Hello": "World"}
@app.post("/v1/generateText")
async def generateText(request: Request):
    request_dict = await request.json()
    prompt = request_dict.pop("prompt", None)
    if not prompt:
        return JSONResponse({"error": "Missing 'prompt' in request body"}, status_code=400)
    # max_tokens == -1 means a single call using each model's configured max_new_tokens.
    max_tokens = request_dict.get("max_tokens", -1)
all_responses = {}
for model_name, llm in llm_models.items():
try:
with get_openai_callback() as cb:
if max_tokens == -1:
full_response = llm(prompt)
                else:
                    # Generate in chunks until the model stops emitting text or the
                    # requested length is reached (max_tokens is compared against the
                    # number of characters here, so it is a coarse cutoff).
                    full_response = ""
                    current_prompt = prompt
                    while True:
                        response_part = llm(current_prompt, max_new_tokens=max_tokens)
                        full_response += response_part
                        if len(full_response) >= max_tokens or response_part == "":
                            break
                        # Continue from the prompt plus everything generated so far.
                        current_prompt = prompt + full_response
print(cb)
all_responses[model_name] = full_response
print(f"Model {model_name}: {full_response}")
except Exception as e:
print(f"Error with model {model_name}: {e}")
if not all_responses:
return JSONResponse({"error": "All models failed to generate text"}, status_code=500)
if sentence_model:
embeddings = sentence_model.encode(list(all_responses.values()))
similarities = cosine_similarity(embeddings)
avg_similarity = similarities.mean(axis=0)
best_model = list(all_responses.keys())[avg_similarity.argmax()]
best_response = all_responses[best_model]
else:
best_model = list(all_responses.keys())[0]
best_response = all_responses[best_model]
return JSONResponse({"best_model": best_model, "text": best_response, "all_responses": all_responses})
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)