import os

# Disable hf_transfer. This must be set before transformers/huggingface_hub
# are imported, since huggingface_hub reads the flag at import time.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "false"

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr

app = FastAPI()

# Debug interface that echoes the request headers. It is mounted on the
# FastAPI app rather than calling iface.launch(), which would block and
# prevent the API routes below from ever being served.
def greet(name, req: gr.Request):
    return f"{req.headers=}"

iface = gr.Interface(fn=greet, inputs="text", outputs="text")
app = gr.mount_gradio_app(app, iface, path="/gradio")
# Load your fine-tuned model and tokenizer
model_name = "OnlyCheeini/greesychat-turbo"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Check if a GPU is available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
).to(device)
class OpenAIRequest(BaseModel):
    model: str
    prompt: str
    max_tokens: int = 64
    temperature: float = 0.7
    top_p: float = 0.9


class OpenAIResponse(BaseModel):
    choices: list
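# Note: OpenAIResponse mirrors only the "choices" field of OpenAI's
# completions schema; fields such as "id", "object", "created", and "usage"
# are omitted here.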
@app.post("/v1/completions", response_model=OpenAIResponse)
async def generate_text(request: OpenAIRequest):
    if request.model != model_name:
        raise HTTPException(status_code=400, detail="Model not found")

    inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=request.max_tokens,
            do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
            temperature=request.temperature,
            top_p=request.top_p,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens so the prompt is not echoed
    # back, matching the OpenAI completions format
    completion_ids = outputs[0][inputs["input_ids"].shape[1]:]
    generated_text = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return OpenAIResponse(choices=[{"text": generated_text}])
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
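# Example client call (a minimal sketch; assumes the server is reachable on
# localhost:8000 and the `requests` package is installed):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/v1/completions",
#       json={
#           "model": "OnlyCheeini/greesychat-turbo",
#           "prompt": "Hello, world",
#           "max_tokens": 32,
#       },
#   )
#   print(resp.json()["choices"][0]["text"])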