ShabazKnowde commited on
Commit
2240148
·
verified ·
1 Parent(s): 3edebfe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -19
app.py CHANGED
@@ -1,25 +1,12 @@
1
  from pydantic import BaseModel
2
  from fastapi.middleware.cors import CORSMiddleware
 
 
3
  from happytransformer import HappyTextToText, TTSettings
4
- from fastapi import FastAPI
5
 
6
  app = FastAPI()
7
- class Generate(BaseModel):
8
- text:str
9
-
10
- def generate_text(prompt: str):
11
- if prompt == "":
12
- return {"detail": "Please provide a prompt."}
13
- else:
14
-
15
-
16
- happy_tt = HappyTextToText("T5", "vennify/t5-base-grammar-correction")
17
-
18
- args = TTSettings(num_beams=5, min_length=1, max_length=100000)
19
-
20
- result = happy_tt.generate_text(f"grammar: {prompt}", args=args)
21
- return result.text
22
 
 
23
  app.add_middleware(
24
  CORSMiddleware,
25
  allow_origins=["*"],
@@ -28,10 +15,23 @@ app.add_middleware(
28
  allow_headers=["*"],
29
  )
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  @app.get("/", tags=["Home"])
32
  def api_home():
33
  return {'detail': 'Welcome to FastAPI TextGen Tutorial!'}
34
 
35
- @app.post("/api/generate", tags=["Generate"], response_model=Generate)
36
- def inference(input_prompt: str):
37
- return generate_text(prompt=input_prompt)
 
1
  from pydantic import BaseModel
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi import FastAPI, HTTPException
4
+
5
  from happytransformer import HappyTextToText, TTSettings
 
6
 
7
  app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # Enable CORS
10
  app.add_middleware(
11
  CORSMiddleware,
12
  allow_origins=["*"],
 
15
  allow_headers=["*"],
16
  )
17
 
18
+ class GenerateInput(BaseModel):
19
+ prompt: str
20
+
21
+ happy_tt = HappyTextToText("T5", "vennify/t5-base-grammar-correction")
22
+ args = TTSettings(num_beams=5, min_length=1, max_length=100000)
23
+
24
+ def generate_text(prompt: str):
25
+ if not prompt.strip():
26
+ raise HTTPException(status_code=400, detail="Please provide a non-empty prompt.")
27
+
28
+ result = happy_tt.generate_text(f"grammar: {prompt}", args=args)
29
+ return result.text
30
+
31
  @app.get("/", tags=["Home"])
32
  def api_home():
33
  return {'detail': 'Welcome to FastAPI TextGen Tutorial!'}
34
 
35
+ @app.post("/api/generate", tags=["Generate"])
36
+ def inference(data: GenerateInput):
37
+ return generate_text(data.prompt)