OjciecTadeusz committed on
Commit
cce0194
·
verified ·
1 Parent(s): 08ccb13

Create app.py

Files changed (1)
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
+ import torch
+
+ app = FastAPI()
+
+ # Model configuration
+ MODEL_NAME = "nlptown/bert-base-multilingual-uncased-sentiment"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Initialize the sentiment model and tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
+ classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, device=DEVICE)
+
+ class TextInput(BaseModel):
+     text: str
+
+ @app.post("/analyze-sentiment")
+ async def analyze_sentiment(input_data: TextInput):
+     try:
+         result = classifier(input_data.text)
+         return {
+             "sentiment": result[0]['label'],
+             "score": float(result[0]['score'])
+         }
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ # Example for a larger model (e.g. GPT-2); a causal LM head is needed for generation
+ MODEL_NAME_LARGE = "gpt2-large"
+ tokenizer_large = AutoTokenizer.from_pretrained(MODEL_NAME_LARGE)
+ model_large = AutoModelForCausalLM.from_pretrained(MODEL_NAME_LARGE).to(DEVICE)
+
+ class GenerationInput(BaseModel):
+     prompt: str
+     max_length: int = 100
+
+ @app.post("/generate-text")
+ async def generate_text(input_data: GenerationInput):
+     try:
+         inputs = tokenizer_large(input_data.prompt, return_tensors="pt").to(DEVICE)
+         outputs = model_large.generate(
+             inputs["input_ids"],
+             max_length=input_data.max_length,
+             num_return_sequences=1,
+             no_repeat_ngram_size=2
+         )
+         generated_text = tokenizer_large.decode(outputs[0], skip_special_tokens=True)
+         return {"generated_text": generated_text}
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ # Basic health check endpoint
+ @app.get("/health")
+ async def health_check():
+     return {"status": "healthy"}
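
A minimal sketch of exercising these endpoints locally, assuming the file is saved as app.py, the server is started with "uvicorn app:app --port 8000", and the requests package is available (the port and the client library are assumptions, not part of this commit):

import requests  # assumed client library; any HTTP client works

BASE_URL = "http://localhost:8000"  # assumed local uvicorn address and port

# Health check
print(requests.get(f"{BASE_URL}/health").json())

# Sentiment analysis endpoint
resp = requests.post(
    f"{BASE_URL}/analyze-sentiment",
    json={"text": "This model is surprisingly good."},
)
print(resp.json())  # prints the predicted label and its score

# Text generation endpoint
resp = requests.post(
    f"{BASE_URL}/generate-text",
    json={"prompt": "FastAPI is", "max_length": 50},
)
print(resp.json())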