# Hugging Face page residue from the original upload (kept as a comment so
# the module remains valid Python):
#   harry85's picture
#   Upload app.py
#   8290afa verified
# Install the necessary packages
# pip install accelerate transformers fastapi pydantic torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from pydantic import BaseModel
from fastapi import FastAPI
# Import the required library
from transformers import pipeline
# Initialize the FastAPI app.
# docs_url="/" would normally serve the interactive Swagger UI at the root
# path; note the GET "/" route defined below shadows it with a JSON greeting.
app = FastAPI(docs_url="/")
# Define the request model
class RequestModel(BaseModel):
    """Request body for POST /generatetext."""

    # The prompt text to be extended by the model.
    # NOTE(review): the field name shadows the builtin ``input``; it is kept
    # because it is the public JSON field name clients already send.
    input: str
# Simple liveness endpoint at the root path.
@app.get("/")
def greet_json():
    """Return a small JSON payload confirming the service is up."""
    status_payload = {"message": "working..."}
    return status_payload
# Shared, lazily-created text-generation pipeline.  Constructing a
# transformers pipeline loads the full GPT-2 model, which is far too
# expensive to repeat per request (the original code rebuilt it inside the
# handler on every call), so it is built once and reused.
_text_generator = None


def _get_text_generator():
    """Return the shared GPT-2 text-generation pipeline, creating it on first use."""
    global _text_generator
    if _text_generator is None:
        _text_generator = pipeline("text-generation", model="gpt2")
    return _text_generator


# Define the text generation endpoint
@app.post("/generatetext")
def get_response(request: RequestModel):
    """Generate a GPT-2 continuation of the prompt in ``request.input``.

    Returns ``{"generated_text": ...}`` where the value is the prompt plus
    the generated continuation (one sequence, capped at 50 tokens total).
    """
    generated_texts = _get_text_generator()(
        request.input,
        max_length=50,           # cap on prompt + continuation length, as before
        num_return_sequences=1,  # single sequence, as before
    )
    # The pipeline returns a list of dicts; extract the single sequence.
    generated_text = generated_texts[0]['generated_text']
    return {"generated_text": generated_text}
# To run the FastAPI app, use the command: uvicorn <filename>:app --reload