from llama_cpp import Llama
from concurrent.futures import ThreadPoolExecutor, as_completed
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import os
from dotenv import load_dotenv
from pydantic import BaseModel
import requests
import traceback
load_dotenv()
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

global_data = {
    'models': {},
    'tokens': {
        'eos': 'eos_token',
        'pad': 'pad_token',
        'padding': 'padding_token',
        'unk': 'unk_token',
        'bos': 'bos_token',
        'sep': 'sep_token',
        'cls': 'cls_token',
        'mask': 'mask_token'
    }
}

model_configs = [
    {"repo_id": "Hjgugugjhuhjggg/mergekit-ties-tzamfyy-Q2_K-GGUF", "filename": "mergekit-ties-tzamfyy-q2_k.gguf", "name": "my_model"}
]
models = {}

def load_model(model_config):
    """Download and initialize a GGUF model once, caching it in the module-level dict."""
    model_name = model_config['name']
    if model_name not in models:
        try:
            model = Llama.from_pretrained(
                repo_id=model_config['repo_id'],
                filename=model_config['filename'],
                use_auth_token=HUGGINGFACE_TOKEN
            )
            models[model_name] = model
            global_data['models'] = models
            return model
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            traceback.print_exc()
            # Mark the model as failed so request handlers can skip it.
            models[model_name] = None
            return None
    return models[model_name]

# Eagerly load every configured model at startup.
for config in model_configs:
    load_model(config)

class ChatRequest(BaseModel):
    message: str
    max_tokens_per_part: int = 256


def normalize_input(input_text):
    return input_text.strip()


def remove_duplicates(text):
    lines = text.split('\n')
    unique_lines = []
    seen_lines = set()
    for line in lines:
        line = line.strip()
        if line and line not in seen_lines:
            unique_lines.append(line)
            seen_lines.add(line)
    return '\n'.join(unique_lines)

def generate_model_response(model, inputs, max_tokens_per_part):
    """Run a single completion against one model and return the de-duplicated text parts."""
    try:
        if model is None:
            # The model failed to load; return no parts rather than raising.
            return []
        responses = []
        response = model(inputs, max_tokens=max_tokens_per_part, stop=["\n\n"])
        if 'choices' not in response or len(response['choices']) == 0 or 'text' not in response['choices'][0]:
            return ["Error: Invalid model response format"]
        text = response['choices'][0]['text']
        if text:
            responses.append(remove_duplicates(text))
        return responses
    except Exception as e:
        print(f"Error generating response: {e}")
        traceback.print_exc()
        return [f"Error: {e}"]

app = FastAPI()
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/generate")
async def generate(request: ChatRequest):
inputs = normalize_input(request.message)
with ThreadPoolExecutor() as executor:
futures = [executor.submit(generate_model_response, model, inputs, request.max_tokens_per_part) for model in models.values()]
responses = [{'model': model_name, 'response': future.result()} for model_name, future in zip(models.keys(), as_completed(futures))]
unique_responses = {}
for response_set in responses:
model_name = response_set['model']
if model_name not in unique_responses:
unique_responses[model_name] = []
unique_responses[model_name].extend(response_set['response'])
formatted_response = ""
for model, response_parts in unique_responses.items():
formatted_response += f"**{model}:**\n"
for i, part in enumerate(response_parts):
formatted_response += f"Part {i+1}:\n{part}\n\n"
return {"response": formatted_response}

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
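
# A minimal client-side sketch of how the /generate endpoint could be called once the
# server is running. It assumes the default port 7860 and a locally reachable host;
# the payload fields mirror ChatRequest above. Shown for illustration only.
#
#   import requests
#   payload = {"message": "Hello, how are you?", "max_tokens_per_part": 128}
#   r = requests.post("http://localhost:7860/generate", json=payload, timeout=300)
#   print(r.json()["response"])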