import os
import hashlib

import uvicorn
import langchain
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from langchain.cache import GPTCache
from langchain.llms import VLLM
from langchain_community.callbacks.manager import get_openai_callback
from gptcache import Cache
from gptcache.manager.factory import manager_factory
from gptcache.processor.pre import get_prompt
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer

app = FastAPI()

def get_hashed_name(name: str) -> str:
    """Return a stable, filesystem-safe identifier for a model name."""
    return hashlib.sha256(name.encode()).hexdigest()

def init_gptcache(cache_obj: Cache, llm_name: str) -> None:
    """Give each model its own on-disk map cache, keyed by the raw prompt."""
    hashed_llm = get_hashed_name(llm_name)
    cache_obj.init(
        pre_embedding_func=get_prompt,
        data_manager=manager_factory(manager="map", data_dir=f"map_cache_{hashed_llm}"),
    )

hf_token = os.environ.get("HF_TOKEN")

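# Four small chat/instruct checkpoints, all run on CPU with a 50-token cap per
# call; HF_TOKEN is forwarded so vLLM can fetch gated checkpoints (e.g.
# meta-llama) from the Hub.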
llm_models = {
    "TinyLlama": VLLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", trust_remote_code=True, max_new_tokens=50, temperature=0.1, use_auth_token=hf_token, device="cpu"),
    "yi-coder": VLLM(model="01-ai/Yi-Coder-1.5B", trust_remote_code=True, max_new_tokens=50, temperature=0.6, use_auth_token=hf_token, device="cpu"),
    "llama": VLLM(model="meta-llama/Llama-3.2-3B-Instruct", trust_remote_code=True, max_new_tokens=50, temperature=0.1, use_auth_token=hf_token, device="cpu"),
    "qwen": VLLM(model="Qwen/Qwen2.5-1.5B-Instruct", trust_remote_code=True, max_new_tokens=50, temperature=0.6, use_auth_token=hf_token, device="cpu"),
}

# Register GPTCache as langchain's LLM cache. langchain calls init_gptcache
# lazily with a per-LLM identifier string, so each model gets its own cache
# directory instead of all of them re-initializing (and clobbering) one
# shared Cache object.
langchain.llm_cache = GPTCache(init_gptcache)

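# The embedding model is only used to rank candidate responses; loading it is
# best-effort, and generation still works if it is unavailable.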
try:
    sentence_model = SentenceTransformer('all-mpnet-base-v2', device='cpu')
except Exception as e:
    print(f"Error loading SentenceTransformer: {e}")
    sentence_model = None

@app.get("/")
def read_root():
    return {"Hello": "World"}

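# Example request (hypothetical values; "prompt" is required, "max_tokens" is
# optional and defaults to -1, i.e. each model's configured 50-token cap):
#
#   curl -X POST http://localhost:7860/v1/generateText \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about caching.", "max_tokens": 120}'
#
# The response carries the consensus pick plus every model's raw output:
#   {"best_model": "...", "text": "...", "all_responses": {...}}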
@app.post("/v1/generateText")
async def generateText(request: Request):
    request_dict = await request.json()
    prompt = request_dict.pop("prompt", None)
    if prompt is None:
        return JSONResponse({"error": "Missing required field: prompt"}, status_code=400)
    max_tokens = request_dict.get("max_tokens", -1)

    all_responses = {}
    for model_name, llm in llm_models.items():
        try:
            # Note: get_openai_callback only meters OpenAI API calls, so for
            # these local vLLM models the printed stats will stay at zero.
            with get_openai_callback() as cb:
                if max_tokens == -1:
                    full_response = llm(prompt)
                else:
                    # Chunked continuation: keep generating until the combined
                    # response reaches max_tokens characters (a rough character
                    # proxy for tokens) or the model returns nothing.
                    full_response = ""
                    current_prompt = prompt
                    while True:
                        response_part = llm(current_prompt, max_new_tokens=max_tokens)
                        full_response += response_part
                        if len(full_response) >= max_tokens or response_part == "":
                            break
                        # Continue from the original prompt plus everything
                        # generated so far, so the model keeps its context.
                        current_prompt = prompt + full_response
                print(cb)
                all_responses[model_name] = full_response
                print(f"Model {model_name}: {full_response}")
        except Exception as e:
            print(f"Error with model {model_name}: {e}")

    if not all_responses:
        return JSONResponse({"error": "All models failed to generate text"}, status_code=500)

    if sentence_model:
        # Consensus pick: embed every response and keep the one with the
        # highest mean cosine similarity to the others, i.e. the answer the
        # models agree on most.
        embeddings = sentence_model.encode(list(all_responses.values()))
        similarities = cosine_similarity(embeddings)
        avg_similarity = similarities.mean(axis=0)
        best_model = list(all_responses.keys())[avg_similarity.argmax()]
    else:
        # No embedding model: fall back to the first model that responded.
        best_model = next(iter(all_responses))
    best_response = all_responses[best_model]

    return JSONResponse({"best_model": best_model, "text": best_response, "all_responses": all_responses})

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
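
# Alternatively (assuming this file is saved as app.py), start the server with:
#   uvicorn app:app --host 0.0.0.0 --port 7860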