CannaTech commited on
Commit
5031659
·
1 Parent(s): 86edf42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -10
app.py CHANGED
@@ -8,19 +8,31 @@ import json
8
  # Set the OpenAI API key
9
  openai.api_key = os.getenv("OPENAI_API_KEY")
10
 
11
- # Set up flagging callback function
 
 
 
 
12
  class CustomHuggingFaceDatasetSaver(gr.FlaggingCallback):
13
- def __init__(self, hf_token, hf_dataset, max_output_length=100):
14
- super().__init__(hf_token, hf_dataset)
15
- self.max_output_length = max_output_length
16
 
17
- def flag(self, inputs, outputs, predicted_outputs):
18
- if predicted_outputs is not None:
19
- predicted_outputs = predicted_outputs[:self.max_output_length]
20
- super().flag(inputs, outputs, predicted_outputs)
 
21
 
22
- HF_TOKEN = os.getenv("HF_TOKEN")
23
- hf_writer = CustomHuggingFaceDatasetSaver(HF_TOKEN, "CannaTech/Flagged")
 
 
 
 
 
 
 
 
24
 
25
  # Define the authentication function
26
  def check_auth(username, password):
 
8
  # Set the OpenAI API key
9
  openai.api_key = os.getenv("OPENAI_API_KEY")
10
 
11
+ # Set up the Hugging Face Dataset Saver
12
+ HF_TOKEN = os.getenv("HF_TOKEN")
13
+ hf_writer = CustomHuggingFaceDatasetSaver(HF_TOKEN, "CannaTech/Flagged")
14
+
15
+ # Define the CustomHuggingFaceDatasetSaver class
16
  class CustomHuggingFaceDatasetSaver(gr.FlaggingCallback):
17
+ def setup(self):
18
+ pass
 
19
 
20
+ def apply(self, flagged_samples):
21
+ for sample in flagged_samples:
22
+ input_data = sample["input"]
23
+ output_data = sample["output"]
24
+ flag = sample["flag"]
25
 
26
+ # Truncate the output if it exceeds 1000 characters
27
+ if len(output_data) > 1000:
28
+ output_data = output_data[:900] + " [Truncated]"
29
+
30
+ # Save the data to Hugging Face dataset
31
+ # Replace this placeholder logic with your own implementation
32
+ # Example: Saving the data to a CSV file
33
+ with open("flagged_data.csv", "a") as f:
34
+ writer = csv.writer(f)
35
+ writer.writerow([input_data, output_data, flag])
36
 
37
  # Define the authentication function
38
  def check_auth(username, password):