Zmorell commited on
Commit
83f48db
·
verified ·
1 Parent(s): 1145787

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -20,10 +20,14 @@ nlp = spacy.load('en_core_web_sm')
20
  # Available backend options are: "jax", "torch", "tensorflow".
21
  import os
22
  os.environ["KERAS_BACKEND"] = "jax"
23
-
 
24
  import keras
25
 
26
- model = keras.saving.load_model("hf://ARI-HIPA-AI-Team/keras_model")
 
 
 
27
 
28
  def preprocess_text(text):
29
  text = re.sub(r'[^a-zA-Z0-9\s]', '', text) # Only remove non-alphanumeric characters except spaces
@@ -39,8 +43,10 @@ def preprocess_text(text):
39
 
40
  def predict(text):
41
  inputs = preprocess_text(text)
 
 
42
  outputs = model(inputs)
43
- return "This text is a violation = " + outputs
44
 
45
  demo = gr.Interface(fn=predict, inputs="text", outputs="text")
46
- demo.launch()
 
20
  # Available backend options are: "jax", "torch", "tensorflow".
21
  import os
22
  os.environ["KERAS_BACKEND"] = "jax"
23
+
24
+ # Ensure the necessary libraries are correctly imported
25
  import keras
26
 
27
+ # Load the model from the Hugging Face repository
28
+ model_path = "https://huggingface.co/Zmorell/HIPA_2/resolve/main/saved_keras_model.keras"
29
+ model = tf.keras.models.load_model(model_path)
30
+ print(f"Model loaded from {model_path}")
31
 
32
  def preprocess_text(text):
33
  text = re.sub(r'[^a-zA-Z0-9\s]', '', text) # Only remove non-alphanumeric characters except spaces
 
43
 
44
  def predict(text):
45
  inputs = preprocess_text(text)
46
+ # Ensure the input shape matches what the model expects
47
+ inputs = tf.convert_to_tensor([inputs])
48
  outputs = model(inputs)
49
+ return "This text is a violation = " + str(outputs[0][0].numpy())
50
 
51
  demo = gr.Interface(fn=predict, inputs="text", outputs="text")
52
+ demo.launch()