# bevelapi/main.py
from fastapi import FastAPI
from fastapi.responses import FileResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()
# Model checkpoint to load. Other checkpoints that work with the same
# AutoModelForCausalLM/AutoTokenizer calls below:
#   microsoft/DialoGPT-small
#   microsoft/DialoGPT-medium
#   microsoft/DialoGPT-large
#   mistralai/Mixtral-8x7B-Instruct-v0.1
name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
customGen = False  # not referenced below

# Load the model and tokenizer from the Hugging Face Hub
model = AutoModelForCausalLM.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)
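
# TinyLlama-1.1B-Chat is instruction-tuned, so it tends to answer better when
# the prompt is wrapped in its chat template. A minimal sketch of a helper for
# that (hypothetical, not wired into the endpoint below; assumes the tokenizer
# ships a chat template, which TinyLlama's does):
def build_chat_prompt(user_message: str) -> str:
    messages = [{"role": "user", "content": user_message}]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )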

class GenerationRequest(BaseModel):
    prompt: str
    length: int

# Serve the front-end page
@app.get("/")
def read_root():
    return FileResponse(path="templates/index.html", media_type="text/html")
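
# If templates/index.html references CSS/JS assets, FastAPI can serve them via
# a static mount; a minimal sketch (the "static" directory name is an assumption):
#
#   from fastapi.staticfiles import StaticFiles
#   app.mount("/static", StaticFiles(directory="static"), name="static")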

@app.post("/api")
def generate_text(data: GenerationRequest):
    print("Prompt:", data.prompt)
    print("Length:", data.length)

    # Tokenize the prompt
    input_ids = tokenizer.encode(data.prompt, return_tensors="pt")

    # Generate with beam search; note that max_length caps the total sequence
    # (prompt tokens + new tokens), not just the generated continuation
    output_ids = model.generate(
        input_ids,
        max_length=data.length,
        num_beams=5,
        no_repeat_ngram_size=2,
    )
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    print("Answer:", generated_text)
    return {"answer": generated_text}
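
# --- Usage sketch ---
# Assuming the app is served with uvicorn (not pinned anywhere in this file):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# The /api endpoint can then be exercised with the `requests` package:
#
#   import requests
#   r = requests.post(
#       "http://localhost:8000/api",
#       json={"prompt": "Hello, how are you?", "length": 50},
#   )
#   print(r.json()["answer"])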