CesarLeblanc committed on
Commit
6176ef8
1 Parent(s): a7e54a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -18
app.py CHANGED
@@ -4,15 +4,25 @@ from datasets import load_dataset
4
  import requests
5
  from bs4 import BeautifulSoup
6
 
7
- classifier = pipeline("text-classification", model="CesarLeblanc/test_model")
8
- dataset = load_dataset("CesarLeblanc/text_classification_dataset")
 
 
 
 
9
 
10
- def text_classification(text, typology, confidence):
11
- result = classifier(text)
12
- habitat_label = result[0]['label']
13
- habitat_label = dataset['train'].features['label'].names[int(habitat_label.split('_')[1])]
14
- habitat_score = result[0]['score']
15
- formatted_output = f"This vegetation plot belongs to the habitat {habitat_label} with the probability {habitat_score*100:.2f}%"
 
 
 
 
 
 
16
  floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
17
  response = requests.get(floraveg_url)
18
  if response.status_code == 200:
@@ -21,22 +31,66 @@ def text_classification(text, typology, confidence):
21
  if img_tag:
22
  image_url = img_tag['src']
23
  else:
24
- image_url = 'https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  else:
26
- image_url = 'https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png'
27
- image_output = gr.Image(value=image_url)
28
  return formatted_output, image_output
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  examples=[
31
- ["sparganium erectum, calystegia sepium, persicaria amphibia", "EUNIS", 50],
32
- ["thinopyrum junceum, cakile maritima", "EUNIS", 50]
33
  ]
34
 
35
- io = gr.Interface(fn=text_classification,
36
- inputs=[gr.Textbox(lines=2, label="List of comma-separated binomial names of species (see examples)", placeholder="Enter species here..."), gr.Dropdown(["EUNIS"], label="Typology", info="Will add more typologies later!"), gr.Slider(0, 100, value=50, label="Confidence", info="Choose the level of confidence for the prediction")],
37
- outputs=[gr.Textbox(lines=2, label="Vegetation Plot Classification Result"), "image"],
38
- title="Pl@ntBERT",
39
- description="Vegetation Plot Classification: enter the species found in a vegetation plot and see its EUNIS habitat!",
40
  examples=examples)
41
 
42
  io.launch()
 
4
  import requests
5
  from bs4 import BeautifulSoup
6
 
7
def return_model(task):
    """Load the Hugging Face pipeline that serves the requested task.

    Args:
        task: ``'classification'`` selects the text-classification pipeline;
            any other value selects the fill-mask pipeline.

    Returns:
        The pipeline object created for the task.
    """
    if task == 'classification':
        model = pipeline("text-classification", model="CesarLeblanc/test_model")
    else:
        model = pipeline("fill-mask", model="CesarLeblanc/fill_mask_model")
    # Return the pipeline itself, not this function object, so that callers
    # can invoke it on their text.
    return model
13
 
14
def return_dataset():
    """Load and return the text-classification dataset used to name labels."""
    return load_dataset("CesarLeblanc/text_classification_dataset")
17
+
18
def return_text(habitat_label, habitat_score, confidence):
    """Build the sentence describing a habitat prediction.

    Args:
        habitat_label: Name of the predicted habitat.
        habitat_score: Score of the prediction; it is multiplied by 100
            before being compared with ``confidence`` and displayed.
        confidence: Minimum accepted score, on the same 0-100 scale.

    Returns:
        A sentence announcing the habitat when the score reaches the
        requested confidence, otherwise a sentence saying that no habitat
        can be assigned.
    """
    # The rejection message promises "at least {confidence}%", and
    # return_image only falls back when the score is strictly below the
    # threshold, so a score exactly on the threshold must be accepted.
    if habitat_score * 100 >= confidence:
        return f"This vegetation plot belongs to the habitat {habitat_label} with the probability {habitat_score*100:.2f}%."
    return f"We can't assign an habitat to this vegetation plot with a confidence of at least {confidence}%."
24
+
25
+ def return_image(habitat_label, habitat_score, confidence):
26
  floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
27
  response = requests.get(floraveg_url)
28
  if response.status_code == 200:
 
31
  if img_tag:
32
  image_url = img_tag['src']
33
  else:
34
+ image_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
35
+ else:
36
+ image_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
37
+ if habitat_score*100 < confidence:
38
+ image_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
39
+ image = gr.Image(value=image_url)
40
+ return image
41
+
42
def classification(text, typology, confidence, task):
    """Classify a species list and return the result text and image.

    Args:
        text: Species list passed to the classification pipeline.
        typology: Selected typology; not used by this function.
        confidence: Threshold forwarded to return_text and return_image.
        task: Task name forwarded to return_model.

    Returns:
        A ``(formatted_output, image_output)`` tuple.
    """
    classifier = return_model(task)
    label_source = return_dataset()
    top_prediction = classifier(text)[0]

    # The pipeline label ends with "_<index>"; the index selects the habitat
    # name from the dataset's label feature.
    label_index = int(top_prediction['label'].split('_')[1])
    habitat_label = label_source['train'].features['label'].names[label_index]
    habitat_score = top_prediction['score']

    return (
        return_text(habitat_label, habitat_score, confidence),
        return_image(habitat_label, habitat_score, confidence),
    )
52
+
53
def masking(text, task):
    """Append a masked binomial name to the species list and fill it in.

    Args:
        text: Comma-separated species list entered by the user.
        task: Task name forwarded to return_model.

    Returns:
        A ``(text, image)`` tuple: the sequence produced by the fill-mask
        pipeline and a placeholder image component.
    """
    model = return_model(task)
    text += ', [MASK] [MASK]'
    # Call the pipeline loaded above; the previous name `mask_filler` was
    # never defined and raised NameError.
    pred = model(text, top_k=1)
    # NOTE(review): the input carries two mask tokens; the fill-mask pipeline
    # documents a nested list result for multi-mask inputs - confirm that
    # pred[0]["sequence"] is valid for this model.
    text = pred[0]["sequence"]
    image = gr.Image(value="https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png")
    return text, image


def plantbert(text, typology, confidence, task):
    """Dispatch the Gradio inputs to the handler of the selected task.

    Args:
        text: Species list entered by the user.
        typology: Selected typology (forwarded to classification only).
        confidence: Confidence threshold (forwarded to classification only).
        task: ``"classification"`` runs classification; any other value runs
            masking.

    Returns:
        A ``(formatted_output, image_output)`` tuple for the two outputs.
    """
    if task == "classification":
        formatted_output, image_output = classification(text, typology, confidence, task)
    else:
        # masking only accepts (text, task); passing all four inputs raised
        # TypeError.
        formatted_output, image_output = masking(text, task)
    return formatted_output, image_output
67
 
68
# Input components, in the order of plantbert's parameters:
# text, typology, confidence, task.
inputs = [
    gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here."),
    gr.Dropdown(["EUNIS"], value="EUNIS", label="Typology", info="Will add more typologies later!"),
    # The comma after the slider was missing, which made the list literal a
    # syntax error.
    gr.Slider(0, 100, value=90, label="Confidence", info="Choose the level of confidence for the prediction."),
    gr.Radio(["classification", "masking"], value="classification", label="Task", info="Which task to choose?"),
]
74
+
75
# Output components: the result sentence and the habitat image.
outputs = [
    gr.Textbox(lines=2, label="Vegetation Plot Classification Result"),
    "image",
]

title = "Pl@ntBERT"

description = "Vegetation Plot Classification: enter the species found in a vegetation plot and see its EUNIS habitat!"

# One example row per task; columns follow the order of `inputs`.
examples = [
    ["sparganium erectum, calystegia sepium, persicaria amphibia", "EUNIS", 90, "classification"],
    ["thinopyrum junceum, cakile maritima", "EUNIS", 90, "masking"],
]

io = gr.Interface(
    fn=plantbert,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
)

io.launch()