DivEye - PR (fix pickling)

#9 opened by FloofCat
Files changed (2)
  1. app.py +1 -1
  2. software.py +33 -10
app.py CHANGED
@@ -51,7 +51,7 @@ def detect_ai_text(text):
     return message, round(ai_prob, 3), bar_data
 
 # Gradio app setup
-with gr.Blocks(title="DivEye", theme=theme) as demo:
+with gr.Blocks(title="DivEye") as demo:
     gr.HTML("""
     <div style="display: flex; justify-content: space-between; align-items: center; padding: 1.5rem; background: #f0f4f8; border-radius: 12px; margin-bottom: 1rem;">
         <div style="text-align: left; max-width: 70%;">
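
The app.py change drops the custom `theme=theme` argument from the Blocks constructor. For reference, a minimal self-contained sketch of the resulting setup; the textbox wiring and the `detect_ai_text` stub below are illustrative assumptions, not the Space's real implementation:

```python
import gradio as gr

def detect_ai_text(text):
    # Stub standing in for the Space's real detector, which returns a
    # message, a rounded AI-probability, and bar-chart data.
    return f"Analyzed {len(text)} characters"

# Gradio app setup: no theme argument, matching the diff above.
with gr.Blocks(title="DivEye") as demo:
    inp = gr.Textbox(label="Text to analyze")
    out = gr.Textbox(label="Result")
    inp.submit(detect_ai_text, inputs=inp, outputs=out)

if __name__ == "__main__":
    demo.launch()
```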
software.py CHANGED
@@ -99,22 +99,42 @@ class BiScope:
 class Software:
     def __init__(self):
         self.token = os.getenv("HF_TOKEN")
+        self.device_div = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.device_bi = self.device_div
 
-        self.div_tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", use_fast=False, trust_remote_code=True, use_auth_token=self.token)
-        self.div_model = AutoModelForCausalLM.from_pretrained(
-            "tiiuae/falcon-7b", torch_dtype=torch.float16, trust_remote_code=True, use_auth_token=self.token
-        )
-
-        self.bi_tokenizer = AutoTokenizer.from_pretrained("google/gemma-1.1-2b-it", use_fast=False, trust_remote_code=True, use_auth_token=self.token)
-        self.bi_model = AutoModelForCausalLM.from_pretrained(
-            "google/gemma-1.1-2b-it", torch_dtype=torch.float16, trust_remote_code=True, use_auth_token=self.token
-        )
-
+        self.div_model = None
+        self.div_tokenizer = None
+        self.bi_model = None
+        self.bi_tokenizer = None
 
         self.model_path = Path(__file__).parent / "model.json"
 
         self.model = xgb.XGBClassifier()
         self.model.load_model(self.model_path)
+
+    def _load_div_models(self):
+        if self.div_model is None or self.div_tokenizer is None:
+            self.div_tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", use_fast=False, trust_remote_code=True, use_auth_token=self.token)
+            self.div_model = AutoModelForCausalLM.from_pretrained(
+                "tiiuae/falcon-7b",
+                device_map="cuda",
+                torch_dtype=torch.float16,
+                trust_remote_code=True,
+                use_auth_token=self.token
+            )
+            self.div_model.to(self.device_div)
+
+    def _load_bi_models(self):
+        if self.bi_model is None or self.bi_tokenizer is None:
+            self.bi_tokenizer = AutoTokenizer.from_pretrained("google/gemma-1.1-2b-it", use_fast=False, trust_remote_code=True, use_auth_token=self.token)
+            self.bi_model = AutoModelForCausalLM.from_pretrained(
+                "google/gemma-1.1-2b-it",
+                device_map="cuda",
+                torch_dtype=torch.float16,
+                trust_remote_code=True,
+                use_auth_token=self.token
+            )
+            self.bi_model.to(self.device_bi)
 
     def load_data(self, jsonl_path):
         ids, texts = [], []
@@ -127,6 +147,9 @@ class Software:
 
     @spaces.GPU
     def evaluate(self, text):
+        self._load_div_models()
+        self._load_bi_models()
+
         # Load models to GPUs.
         device_div = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         if torch.cuda.device_count() > 1:
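
The pickling fix itself is the lazy-initialization pattern: `__init__` now stores only `None` placeholders and device handles, and the heavy Falcon-7B and Gemma models are constructed inside `evaluate`, after the `@spaces.GPU` decorator has taken effect. On ZeroGPU Spaces the decorated call runs in a separate process, so state reachable from it must be picklable, and a model already loaded onto CUDA generally is not. A minimal runnable sketch of the pattern, with a stand-in loader instead of real transformers weights (all names below are illustrative, not the Space's API):

```python
import os

class LazyDetector:
    """Illustrates the lazy-load pattern from software.py above."""

    def __init__(self):
        # Keep construction-time state small and picklable: a token
        # string and a None placeholder, never a loaded model.
        self.token = os.getenv("HF_TOKEN")
        self.model = None

    def _load_model(self):
        # Idempotent: builds the expensive object on first use, inside
        # the (GPU-decorated) call rather than at construction time.
        if self.model is None:
            self.model = self._build_model()

    def _build_model(self):
        # Stand-in for AutoModelForCausalLM.from_pretrained(...).
        return lambda text: len(text.split())

    def evaluate(self, text):
        self._load_model()           # first call pays the load cost
        return self.model(text)      # later calls reuse the cache


detector = LazyDetector()
print(detector.evaluate("hello world"))  # model is built here, lazily
```

Because the `is None` guards make the loaders idempotent, repeated `evaluate` calls reuse the cached models rather than reloading the weights on every request.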