Hjgugugjhuhjggg committed
Commit 9800fab · verified · 1 Parent(s): 061bbfa

Create app.py

Files changed (1)
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
+ import os
+ import hashlib
+ import uvicorn
+ from fastapi import FastAPI, Request
+ from fastapi.responses import JSONResponse
+ from langchain_community.llms import VLLM
+ from gptcache import Cache
+ from gptcache.manager.factory import manager_factory
+ from gptcache.processor.pre import get_prompt
+ from langchain_community.cache import GPTCache
+ from langchain.globals import set_llm_cache
+ from langchain_community.callbacks import get_openai_callback
+ from sklearn.metrics.pairwise import cosine_similarity
+ from sentence_transformers import SentenceTransformer
+
+ app = FastAPI()
+
+ def get_hashed_name(name: str) -> str:
+     return hashlib.sha256(name.encode()).hexdigest()
+
+ def init_gptcache(cache_obj: Cache, llm: str):
+     # One on-disk "map" cache per LLM, in a directory keyed by the hashed model name.
+     hashed_llm = get_hashed_name(llm)
+     cache_obj.init(pre_embedding_func=get_prompt, data_manager=manager_factory(manager="map", data_dir=f"map_cache_{hashed_llm}"))
+
+ # GPTCache takes the init function itself and lazily creates one Cache per LLM
+ # string, so no standalone Cache object or manual per-model init loop is needed.
+ set_llm_cache(GPTCache(init_gptcache))
+
+ # Gated checkpoints (e.g. meta-llama) are authenticated via the HF_TOKEN
+ # environment variable, which huggingface_hub and vLLM read automatically.
+
+ # NOTE: instantiating four vLLM engines at once requires substantial GPU/CPU memory.
+ llm_models = {
+     "TinyLlama": VLLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", trust_remote_code=True, max_new_tokens=50, temperature=0.1),
+     "yi-coder": VLLM(model="01-ai/Yi-Coder-1.5B", trust_remote_code=True, max_new_tokens=50, temperature=0.6),
+     "llama": VLLM(model="meta-llama/Llama-3.2-3B-Instruct", trust_remote_code=True, max_new_tokens=50, temperature=0.1),
+     "qwen": VLLM(model="Qwen/Qwen2.5-1.5B-Instruct", trust_remote_code=True, max_new_tokens=50, temperature=0.6),
+ }
+
+ try:
+     # Embedding model used to score agreement between the candidate responses.
+     sentence_model = SentenceTransformer('all-mpnet-base-v2', device='cpu')
+ except Exception as e:
+     print(f"Error loading SentenceTransformer: {e}")
+     sentence_model = None
+
+ @app.get("/")
+ def read_root():
+     return {"Hello": "World"}
+
+ @app.post("/v1/generateText")
+ async def generateText(request: Request):
+     request_dict = await request.json()
+     prompt = request_dict.get("prompt")
+     if prompt is None:
+         return JSONResponse({"error": "Missing required field: prompt"}, status_code=400)
+     max_tokens = request_dict.get("max_tokens", -1)
+
+     all_responses = {}
+     for model_name, llm in llm_models.items():
+         try:
+             # get_openai_callback only meters OpenAI calls, so it reports zeros
+             # for vLLM models; kept as a placeholder for token accounting.
+             with get_openai_callback() as cb:
+                 if max_tokens == -1:
+                     full_response = llm.invoke(prompt)
+                 else:
+                     # Generate in chunks (each capped by the model's max_new_tokens)
+                     # until the budget is reached, re-feeding the original prompt plus
+                     # the text so far. NOTE: len() counts characters, not tokens,
+                     # so max_tokens is only a rough cap here.
+                     full_response = ""
+                     while True:
+                         response_part = llm.invoke(prompt + full_response)
+                         full_response += response_part
+                         if len(full_response) >= max_tokens or response_part == "":
+                             break
+                 print(cb)
+             all_responses[model_name] = full_response
+             print(f"Model {model_name}: {full_response}")
+         except Exception as e:
+             print(f"Error with model {model_name}: {e}")
+
+     if not all_responses:
+         return JSONResponse({"error": "All models failed to generate text"}, status_code=500)
+
+     if sentence_model:
+         # Consensus vote: keep the response with the highest mean cosine
+         # similarity to all responses (including itself).
+         embeddings = sentence_model.encode(list(all_responses.values()))
+         similarities = cosine_similarity(embeddings)
+         avg_similarity = similarities.mean(axis=0)
+         best_model = list(all_responses.keys())[avg_similarity.argmax()]
+     else:
+         # Embedding model unavailable; fall back to the first successful response.
+         best_model = list(all_responses.keys())[0]
+     best_response = all_responses[best_model]
+
+     return JSONResponse({"best_model": best_model, "text": best_response, "all_responses": all_responses})
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=5001)
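
The selection step above keeps whichever response has the highest mean cosine similarity to the full set of responses, i.e. the most consensus-like answer. A minimal, self-contained sketch of that logic with toy embedding vectors (hypothetical numbers; in the app the embeddings come from the SentenceTransformer model):

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Toy 3-D "embeddings" for three responses; two agree, one is an outlier.
embeddings = np.array([
    [1.0, 0.1, 0.0],   # TinyLlama
    [0.9, 0.2, 0.1],   # qwen, close to TinyLlama
    [0.0, 0.0, 1.0],   # llama, the outlier
])
names = ["TinyLlama", "qwen", "llama"]

similarities = cosine_similarity(embeddings)  # 3x3 pairwise similarity matrix
avg_similarity = similarities.mean(axis=0)    # mean similarity per response
print(names[avg_similarity.argmax()])         # an agreeing response wins, not the outlier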
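
And a hypothetical client-side smoke test once the server is running (assumes the host/port above; max_tokens is optional and defaults to -1, meaning one uncapped call per model):

import requests

# Hypothetical request against the running service.
resp = requests.post(
    "http://localhost:5001/v1/generateText",
    json={"prompt": "Write a haiku about caching.", "max_tokens": 200},
)
data = resp.json()
print(data["best_model"])  # model whose response won the similarity vote
print(data["text"])        # the selected response text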