MJobe commited on
Commit
8601b67
·
verified ·
1 Parent(s): 50e0c0e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +21 -0
main.py CHANGED
@@ -35,6 +35,7 @@ nlp_qa_v3 = pipeline("question-answering", model="deepset/roberta-base-squad2")
35
  nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
36
  nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
37
  nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
 
38
 
39
  description = """
40
  ## Image-based Document QA
@@ -262,6 +263,26 @@ async def test_transcription(file: UploadFile = File(...)):
262
 
263
  except Exception as e:
264
  raise HTTPException(status_code=500, detail=f"Error during transcription: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
  # Set up CORS middleware
267
  origins = ["*"] # or specify your list of allowed origins
 
35
  nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
36
  nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
37
  nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
38
+ nlp_prompt_classification = pipeline("text2text-generation", model="google/flan-t5-large")
39
 
40
  description = """
41
  ## Image-based Document QA
 
263
 
264
  except Exception as e:
265
  raise HTTPException(status_code=500, detail=f"Error during transcription: {str(e)}")
266
+
267
+ @app.post("/prompt_classify/", description="Classify the provided statement into one of the predefined categories.")
268
+ async def prompt_classify_text(statement: str = Form(...)):
269
+ try:
270
+ # Predefined prompt with placeholders
271
+ prompt = (
272
+ "Please classify the statement in one of the following classifications: "
273
+ "Negative, Neutral, Positive"
274
+ f"Statement: {statement}"
275
+ )
276
+
277
+ # Generate the response based on the prompt
278
+ result = nlp_prompt_classification(prompt, max_length=50, num_return_sequences=1)
279
+
280
+ # Extract the generated classification from the response
281
+ classification = result[0]['generated_text'].strip()
282
+
283
+ return {"classification": classification}
284
+ except Exception as e:
285
+ return JSONResponse(content=f"Error in prompt classification: {str(e)}", status_code=500)
286
 
287
  # Set up CORS middleware
288
  origins = ["*"] # or specify your list of allowed origins