Jesus Andres Correal Ortiz committed on
Commit
0a0007f
·
1 Parent(s): 3719454

Updated app.py to fix error in dtypes

Browse files
Files changed (2) hide show
  1. app.py +15 -3
  2. fine-tuning.ipynb +5 -0
app.py CHANGED
@@ -7,21 +7,33 @@ model_name = "acorreal/phi3-project-management"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
 
 
 
 
 
 
10
  # Streamlit app
11
  st.title('Project Management Educational Tutor')
12
- st.write('This app uses the "acorreal/phi3-project-management" model')
 
13
 
14
  user_input = st.text_area("Enter your project management question or topic here:")
15
 
16
  if st.button('Get Response'):
17
  if user_input:
18
- inputs = tokenizer(user_input, return_tensors="pt")
 
 
 
 
 
 
19
  with torch.no_grad():
20
  outputs = model(**inputs)
21
  logits = outputs.logits
22
  predicted_class_id = logits.argmax().item()
23
 
24
  st.write(f"Predicted class ID: {predicted_class_id}")
25
- # You can add more logic here to provide detailed responses based on the predicted_class_id
26
  else:
27
  st.write("Please enter a question or topic to get a response.")
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
 
10
+ # Ensure the model is on the correct device and using the right dtype
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ model.to(device)
13
+ model.eval() # Set model to evaluation mode
14
+
15
  # Streamlit app
16
  st.title('Project Management Educational Tutor')
17
+
18
+ st.write('This app uses the "acorreal/phi3-project-management" model to provide insights on project management topics.')
19
 
20
  user_input = st.text_area("Enter your project management question or topic here:")
21
 
22
  if st.button('Get Response'):
23
  if user_input:
24
+
25
+ # Tokenize the input and move it to the correct device
26
+ inputs = tokenizer(user_input, return_tensors="pt").to(device)
27
+
28
+ # Ensure inputs are in the correct dtype
29
+ inputs = {k: v.to(dtype=torch.float32 if model.dtype == torch.float32 else torch.float16) for k, v in inputs.items()}
30
+
31
  with torch.no_grad():
32
  outputs = model(**inputs)
33
  logits = outputs.logits
34
  predicted_class_id = logits.argmax().item()
35
 
36
  st.write(f"Predicted class ID: {predicted_class_id}")
37
+
38
  else:
39
  st.write("Please enter a question or topic to get a response.")
fine-tuning.ipynb CHANGED
@@ -71,8 +71,13 @@
71
  },
72
  "outputs": [],
73
  "source": [
 
74
  "base_model_name = \"microsoft/Phi-3-mini-4k-instruct\"\n",
 
 
75
  "model_name=\"acorreal/project-management-tutor\"\n",
 
 
76
  "use_4bit = True\n",
77
  "bnb_4bit_quant_type = \"nf4\"\n",
78
  "use_double_quant = True\n",
 
71
  },
72
  "outputs": [],
73
  "source": [
74
+ "# Name of the model to use as parent model\n",
75
  "base_model_name = \"microsoft/Phi-3-mini-4k-instruct\"\n",
76
+ "\n",
77
+ "# Name of the new model\n",
78
  "model_name=\"acorreal/project-management-tutor\"\n",
79
+ "\n",
80
+ "# Set the model configuration\n",
81
  "use_4bit = True\n",
82
  "bnb_4bit_quant_type = \"nf4\"\n",
83
  "use_double_quant = True\n",