Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -35,6 +35,7 @@ nlp_qa_v3 = pipeline("question-answering", model="deepset/roberta-base-squad2")
|
|
35 |
nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
|
36 |
nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
|
37 |
nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
|
|
|
38 |
|
39 |
description = """
|
40 |
## Image-based Document QA
|
@@ -262,6 +263,26 @@ async def test_transcription(file: UploadFile = File(...)):
|
|
262 |
|
263 |
except Exception as e:
|
264 |
raise HTTPException(status_code=500, detail=f"Error during transcription: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
|
266 |
# Set up CORS middleware
|
267 |
origins = ["*"] # or specify your list of allowed origins
|
|
|
35 |
nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
|
36 |
nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
|
37 |
nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
|
38 |
+
nlp_prompt_classification = pipeline("text2text-generation", model="google/flan-t5-large")
|
39 |
|
40 |
description = """
|
41 |
## Image-based Document QA
|
|
|
263 |
|
264 |
except Exception as e:
|
265 |
raise HTTPException(status_code=500, detail=f"Error during transcription: {str(e)}")
|
266 |
+
|
267 |
+
@app.post("/prompt_classify/", description="Classify the provided statement into one of the predefined categories.")
|
268 |
+
async def prompt_classify_text(statement: str = Form(...)):
|
269 |
+
try:
|
270 |
+
# Predefined prompt with placeholders
|
271 |
+
prompt = (
|
272 |
+
"Please classify the statement in one of the following classifications: "
|
273 |
+
"Negative, Neutral, Positive"
|
274 |
+
f"Statement: {statement}"
|
275 |
+
)
|
276 |
+
|
277 |
+
# Generate the response based on the prompt
|
278 |
+
result = nlp_prompt_classification(prompt, max_length=50, num_return_sequences=1)
|
279 |
+
|
280 |
+
# Extract the generated classification from the response
|
281 |
+
classification = result[0]['generated_text'].strip()
|
282 |
+
|
283 |
+
return {"classification": classification}
|
284 |
+
except Exception as e:
|
285 |
+
return JSONResponse(content=f"Error in prompt classification: {str(e)}", status_code=500)
|
286 |
|
287 |
# Set up CORS middleware
|
288 |
origins = ["*"] # or specify your list of allowed origins
|