JanviMl commited on
Commit
39459c9
·
verified ·
1 Parent(s): 558e729

Update refine_paraphrases.py

Browse files
Files changed (1) hide show
  1. refine_paraphrases.py +21 -6
refine_paraphrases.py CHANGED
@@ -3,10 +3,12 @@ import pandas as pd
3
  from paraphraser import paraphrase_comment
4
  from metrics import compute_reward_scores
5
  from model_loader import paraphraser_model
 
 
6
 
7
  # Configuration
8
- DATA_PATH = "toxic-comment-classifier_rlhf/refined_paraphrases.csv"
9
- OUTPUT_PATH = "toxic-comment-classifier_rlhf/iterated_paraphrases.csv"
10
  MAX_ITERATIONS = 3
11
  TARGET_SCORES = {
12
  "empathy": 0.9,
@@ -83,8 +85,12 @@ def refine_paraphrase(row: pd.Series) -> tuple:
83
  return current_paraphrase, current_scores, "; ".join(reasoning)
84
 
85
  def main():
86
- # Load dataset
87
- df = pd.read_csv(DATA_PATH)
 
 
 
 
88
 
89
  # Process each row
90
  results = []
@@ -105,11 +111,20 @@ def main():
105
  "Iteration_Reasoning": reasoning
106
  }
107
  results.append(result)
108
-
109
- # Save results
110
  result_df = pd.DataFrame(results)
111
  result_df.to_csv(OUTPUT_PATH, index=False)
112
  print(f"Refinement complete. Results saved to {OUTPUT_PATH}")
 
 
 
 
 
 
 
 
 
113
 
114
  if __name__ == "__main__":
115
  main()
 
3
  from paraphraser import paraphrase_comment
4
  from metrics import compute_reward_scores
5
  from model_loader import paraphraser_model
6
+ from datasets import load_dataset
7
+ import os
8
 
9
  # Configuration
10
+ DATA_PATH = "JanviMl/toxi_refined_paraphrases"
11
+ OUTPUT_PATH = "iterated_paraphrases.csv"
12
  MAX_ITERATIONS = 3
13
  TARGET_SCORES = {
14
  "empathy": 0.9,
 
85
  return current_paraphrase, current_scores, "; ".join(reasoning)
86
 
87
  def main():
88
+ # Load dataset from Hugging Face Hub
89
+ try:
90
+ df = load_dataset(DATA_PATH, split="train").to_pandas()
91
+ except Exception as e:
92
+ print(f"Error loading dataset: {str(e)}")
93
+ return
94
 
95
  # Process each row
96
  results = []
 
111
  "Iteration_Reasoning": reasoning
112
  }
113
  results.append(result)
114
+
115
+ # Save results locally
116
  result_df = pd.DataFrame(results)
117
  result_df.to_csv(OUTPUT_PATH, index=False)
118
  print(f"Refinement complete. Results saved to {OUTPUT_PATH}")
119
+
120
+ # Push to Hugging Face Hub
121
+ try:
122
+ from datasets import Dataset
123
+ dataset = Dataset.from_pandas(result_df)
124
+ dataset.push_to_hub("JanviMl/toxi_iterated_paraphrases", token=os.getenv("HF_TOKEN"))
125
+ print("Pushed to Hugging Face Hub: JanviMl/toxi_iterated_paraphrases")
126
+ except Exception as e:
127
+ print(f"Error pushing to Hub: {str(e)}")
128
 
129
  if __name__ == "__main__":
130
  main()