vrkforever commited on
Commit
afcca73
1 Parent(s): 88d9793

made changes to model

Browse files
Files changed (2) hide show
  1. app.py +17 -5
  2. testmodel.py +14 -0
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import nltk
 
 
2
  from fastapi import FastAPI
3
  from fastapi.responses import JSONResponse
4
  from pydantic import BaseModel
@@ -8,9 +10,19 @@ import joblib
8
  nltk.download('wordnet', quiet=True)
9
  nltk.download('stopwords', quiet=True)
10
 
 
 
 
11
  # Load the trained model
12
  model = joblib.load('disaster_classification_model.joblib')
13
 
 
 
 
 
 
 
 
14
  app = FastAPI()
15
 
16
  class TextRequest(BaseModel):
@@ -18,11 +30,11 @@ class TextRequest(BaseModel):
18
 
19
  @app.post("/predict")
20
  async def predict(request: TextRequest):
21
- text = request.text
22
- # The preprocessing is now handled by the loaded pipeline
23
- prediction = model.predict([text])[0]
24
- result = "disaster" if prediction == 1 else "not"
25
- return JSONResponse(content={"output": result})
26
 
27
  @app.get("/")
28
  async def root():
 
1
  import nltk
2
+ from nltk.corpus import stopwords
3
+ from nltk.stem import WordNetLemmatizer
4
  from fastapi import FastAPI
5
  from fastapi.responses import JSONResponse
6
  from pydantic import BaseModel
 
10
  nltk.download('wordnet', quiet=True)
11
  nltk.download('stopwords', quiet=True)
12
 
13
+ # Initialize lemmatizer
14
+ lemmatizer = WordNetLemmatizer()
15
+
16
  # Load the trained model
17
  model = joblib.load('disaster_classification_model.joblib')
18
 
19
+ def improved_preprocess(text):
20
+ text = text.lower()
21
+ text = ''.join([char for char in text if char not in string.punctuation])
22
+ words = text.split()
23
+ words = [lemmatizer.lemmatize(word) for word in words if word not in stopwords.words('english')]
24
+ return ' '.join(words)
25
+
26
  app = FastAPI()
27
 
28
  class TextRequest(BaseModel):
 
30
 
31
  @app.post("/predict")
32
  async def predict(request: TextRequest):
33
+ text = request.text
34
+ new_text_processed = [improved_preprocess(text)]
35
+ prediction = model.predict(new_text_processed)
36
+ result = "disaster" if prediction == 1 else "not"
37
+ return JSONResponse(content={"output": result})
38
 
39
  @app.get("/")
40
  async def root():
testmodel.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import joblib
2
+
3
+ # Load the trained model
4
+ model = joblib.load('disaster_classification_model.joblib')
5
+
6
+ # Sample text to test
7
+ text = "This is a test message to check the model."
8
+
9
+ # Make prediction
10
+ prediction = model.predict([text])[0]
11
+
12
+ # Print result
13
+ result = "disaster" if prediction == 1 else "not"
14
+ print(f"Prediction: {result}")