ShabazKnowde commited on
Commit
4f9c611
·
verified ·
1 Parent(s): 65e2ad9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -13
app.py CHANGED
@@ -1,20 +1,41 @@
1
- from fastapi import FastAPI
2
  from pydantic import BaseModel
 
 
 
3
  from happytransformer import HappyTextToText, TTSettings
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- app = FastAPI()
6
 
7
- happy_tt = HappyTextToText("T5", "vennify/t5-base-grammar-correction")
 
 
 
 
 
8
 
9
- class InputText(BaseModel):
10
- txt: str
 
 
 
 
 
11
 
12
- @app.post("/correct_grammar/")
13
- def correct_grammar(input_text: InputText):
14
- args = TTSettings(num_beams=5, min_length=1, max_length=100000)
15
- corrected_text = happy_tt.generate_text(f"grammar: {input_text.txt}", args=args)
16
- return {"corrected_text": corrected_text.text}
17
 
18
- if __name__ == "__main__":
19
- import uvicorn
20
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
1
  from pydantic import BaseModel
2
+
3
+ from .ConfigEnv import config
4
+ from fastapi.middleware.cors import CORSMiddleware
5
  from happytransformer import HappyTextToText, TTSettings
6
+ from langchain.llms import Clarifai
7
+ from langchain.chains import LLMChain
8
+ from langchain.prompts import PromptTemplate
9
+ from TextGen import app
10
+
11
+ class Generate(BaseModel):
12
+ text:str
13
+
def generate_text(prompt: str):
    """Correct the grammar of *prompt* with a T5 grammar-correction model.

    Args:
        prompt: Raw text to correct. An empty string is rejected.

    Returns:
        The corrected text as a ``str``, or a ``{"detail": ...}`` dict when
        *prompt* is the empty string.
    """
    if prompt == "":
        return {"detail": "Please provide a prompt."}

    # The original code built a fresh HappyTextToText (downloading /
    # initialising the checkpoint) on every call; use the cached instance
    # so the expensive load happens once per process.
    happy_tt = _get_happy_tt()
    args = TTSettings(num_beams=5, min_length=1, max_length=100000)
    result = happy_tt.generate_text(f"grammar: {prompt}", args=args)
    return result.text


def _get_happy_tt():
    """Return a process-wide cached ``HappyTextToText`` instance (lazy)."""
    if not hasattr(_get_happy_tt, "_model"):
        _get_happy_tt._model = HappyTextToText(
            "T5", "vennify/t5-base-grammar-correction"
        )
    return _get_happy_tt._model
26
 
# Permit cross-origin requests so browser front-ends hosted elsewhere can
# call this API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# wide open — tighten the origin list before exposing this publicly.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
34
 
@app.get("/", tags=["Home"])
def api_home():
    """Landing endpoint: confirms the API is up."""
    welcome = {'detail': 'Welcome to FastAPI TextGen Tutorial!'}
    return welcome
 
 
38
 
@app.post("/api/generate", tags=["Generate"], response_model=Generate)
def inference(input_prompt: str):
    """Run grammar correction on *input_prompt* and return it as ``Generate``.

    The original returned a bare ``str``, which fails FastAPI's response
    validation against ``response_model=Generate`` (it expects a ``text``
    field); wrap the string result in the expected shape.
    """
    result = generate_text(prompt=input_prompt)
    if isinstance(result, dict):
        # Empty-prompt error path from generate_text.
        # NOTE(review): this dict has no "text" key, so it still fails the
        # response_model validation — consider raising HTTPException(400)
        # here instead; confirm desired error contract with the caller.
        return result
    return {"text": result}