Spaces: Runtime error

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch

app = FastAPI()

# Model configuration
MODEL_NAME = "nlptown/bert-base-multilingual-uncased-sentiment"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize sentiment analysis model
sentiment_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
sentiment_classifier = pipeline(
    "sentiment-analysis",
    model=MODEL_NAME,
    tokenizer=sentiment_tokenizer,
    # Note: recent transformers releases accept a device string here;
    # older releases expect an integer index (-1 for CPU, 0 for GPU).
    device=DEVICE,
)

# Initialize GPT-2 for text generation
MODEL_NAME_LARGE = "gpt2-large"
generation_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_LARGE)
generation_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME_LARGE).to(DEVICE)

class TextInput(BaseModel):
    text: str


class GenerationInput(BaseModel):
    prompt: str
    max_length: int = 100

# Route path assumed; without a decorator FastAPI never registers this endpoint.
@app.post("/sentiment")
async def analyze_sentiment(input_data: TextInput):
    try:
        result = sentiment_classifier(input_data.text)
        return {
            "sentiment": result[0]["label"],
            "score": float(result[0]["score"]),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Route path assumed; register the handler so FastAPI exposes it.
@app.post("/generate")
async def generate_text(input_data: GenerationInput):
    try:
        inputs = generation_tokenizer(
            input_data.prompt,
            return_tensors="pt",
        ).to(DEVICE)
        outputs = generation_model.generate(
            inputs["input_ids"],
            max_length=input_data.max_length,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            pad_token_id=generation_tokenizer.eos_token_id,
        )
        generated_text = generation_tokenizer.decode(
            outputs[0],
            skip_special_tokens=True,
        )
        return {"generated_text": generated_text}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Route path assumed.
@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "sentiment_model": MODEL_NAME,
        "generation_model": MODEL_NAME_LARGE,
        "device": str(DEVICE),
    }

# Add this at the end of the file
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
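
For reference, here is a minimal sketch of how the endpoints could be exercised once the app starts. The route paths (/health, /sentiment, /generate) and the base URL are assumptions carried over from the decorators above, not confirmed routes from the original Space.

# Hypothetical smoke test for the API above (assumed routes, local URL).
import requests

BASE_URL = "http://localhost:8000"  # replace with your Space URL

print(requests.get(f"{BASE_URL}/health").json())
print(requests.post(f"{BASE_URL}/sentiment", json={"text": "I love this!"}).json())
print(requests.post(
    f"{BASE_URL}/generate",
    json={"prompt": "Once upon a time", "max_length": 50},
).json())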