hyeongnym commited on
Commit
5e4d10b
ยท
verified ยท
1 Parent(s): 4f7e18c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -53
app.py CHANGED
@@ -33,6 +33,40 @@ logging.basicConfig(
33
  )
34
  logger = logging.getLogger(__name__)
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # ์„ค์ • ํด๋ž˜์Šค
37
  class Config:
38
  def __init__(self):
@@ -55,8 +89,29 @@ class ChatResponse(BaseModel):
55
  status: str
56
  timestamp: datetime
57
 
 
 
 
 
 
 
 
 
 
 
58
  # ํŒŒ์ผ ์ฒ˜๋ฆฌ ํด๋ž˜์Šค
59
  class FileProcessor:
 
 
 
 
 
 
 
 
 
 
 
60
  @staticmethod
61
  def process_pdf(file_path):
62
  try:
@@ -78,22 +133,34 @@ class FileProcessor:
78
  @staticmethod
79
  def process_csv(file_path):
80
  try:
81
- encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
82
- for encoding in encodings:
83
- try:
84
- return pd.read_csv(file_path, encoding=encoding)
85
- except UnicodeDecodeError:
86
- continue
87
- raise FileProcessingError("Unable to read CSV with supported encodings")
88
  except Exception as e:
89
  raise FileProcessingError(f"CSV processing error: {str(e)}")
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  # ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ
92
  @torch.no_grad()
93
  def clear_cuda_memory():
94
  if torch.cuda.is_available():
95
  torch.cuda.empty_cache()
96
  gc.collect()
 
 
97
 
98
  # ๋ชจ๋ธ ๋กœ๋“œ
99
  @spaces.GPU
@@ -129,32 +196,19 @@ def find_relevant_context(query, top_k=3):
129
  except Exception as e:
130
  logger.error(f"Context search error: {str(e)}")
131
  return []
132
-
133
  # ์ŠคํŠธ๋ฆฌ๋ฐ ์ฑ„ํŒ…
134
  @spaces.GPU
135
  def stream_chat(message: str, history: list, uploaded_file, temperature: float,
136
  max_new_tokens: int, top_p: float, top_k: int, penalty: float) -> Iterator[Tuple[str, list]]:
137
  """
138
  ์ŠคํŠธ๋ฆฌ๋ฐ ์ฑ„ํŒ… ์‘๋‹ต์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
139
-
140
- Args:
141
- message (str): ์‚ฌ์šฉ์ž ์ž…๋ ฅ ๋ฉ”์‹œ์ง€
142
- history (list): ๋Œ€ํ™” ํžˆ์Šคํ† ๋ฆฌ
143
- uploaded_file: ์—…๋กœ๋“œ๋œ ํŒŒ์ผ
144
- temperature (float): ์ƒ์„ฑ ์˜จ๋„
145
- max_new_tokens (int): ์ตœ๋Œ€ ํ† ํฐ ์ˆ˜
146
- top_p (float): ์ƒ์œ„ p ์ƒ˜ํ”Œ๋ง
147
- top_k (int): ์ƒ์œ„ k ์ƒ˜ํ”Œ๋ง
148
- penalty (float): ๋ฐ˜๋ณต ํŽ˜๋„ํ‹ฐ
149
-
150
- Returns:
151
- Iterator[Tuple[str, list]]: ์ƒ์„ฑ๋œ ์‘๋‹ต๊ณผ ์—…๋ฐ์ดํŠธ๋œ ํžˆ์Šคํ† ๋ฆฌ
152
  """
153
- global model, current_file_context
154
 
155
  try:
156
- if model is None:
157
- model = load_model()
 
158
 
159
  logger.info(f'Processing message: {message}')
160
  logger.debug(f'History length: {len(history)}')
@@ -169,9 +223,9 @@ def stream_chat(message: str, history: list, uploaded_file, temperature: float,
169
  elif file_ext == '.csv':
170
  content = FileProcessor.process_csv(uploaded_file.name)
171
  else:
172
- content = safe_file_read(uploaded_file.name)
173
 
174
- file_context = analyze_file_content(content, file_ext)
175
  current_file_context = file_context
176
  except Exception as e:
177
  logger.error(f"File processing error: {str(e)}")
@@ -199,7 +253,16 @@ def stream_chat(message: str, history: list, uploaded_file, temperature: float,
199
  return_tensors="pt"
200
  ).to("cuda")
201
 
202
- streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
203
 
204
  generate_kwargs = dict(
205
  inputs,
@@ -215,13 +278,14 @@ def stream_chat(message: str, history: list, uploaded_file, temperature: float,
215
 
216
  clear_cuda_memory()
217
 
218
- thread = Thread(target=model.generate, kwargs=generate_kwargs)
219
- thread.start()
 
220
 
221
- buffer = ""
222
- for new_text in streamer:
223
- buffer += new_text
224
- yield "", history + [[message, buffer]]
225
 
226
  clear_cuda_memory()
227
 
@@ -232,8 +296,7 @@ def stream_chat(message: str, history: list, uploaded_file, temperature: float,
232
 
233
  # UI ์ƒ์„ฑ
234
  def create_demo():
235
- with gr.Blocks(css=UPDATED_CSS) as demo:
236
- # UI ์ปดํฌ๋„ŒํŠธ ๊ตฌ์„ฑ
237
  with gr.Column(elem_classes="markdown-style"):
238
  gr.Markdown("""
239
  # ๐Ÿค– RAGOndevice
@@ -244,11 +307,10 @@ def create_demo():
244
  chatbot = gr.Chatbot(
245
  value=[],
246
  height=600,
247
- label="GiniGEN AI Assistant",
248
  elem_classes="chat-container"
249
  )
250
 
251
- # ์ž…๋ ฅ ์ปดํฌ๋„ŒํŠธ
252
  with gr.Row(elem_classes="input-container"):
253
  with gr.Column(scale=1, min_width=70):
254
  file_upload = gr.File(
@@ -283,7 +345,6 @@ def create_demo():
283
  scale=1
284
  )
285
 
286
- # ๊ณ ๊ธ‰ ์„ค์ •
287
  with gr.Accordion("๐ŸŽฎ Advanced Settings", open=False):
288
  with gr.Row():
289
  with gr.Column(scale=1):
@@ -318,26 +379,43 @@ def create_demo():
318
 
319
  # ๋ฉ”์ธ ์‹คํ–‰
320
  if __name__ == "__main__":
321
- # ์œ„ํ‚คํ”ผ๋””์•„ ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ
322
- wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
323
- logger.info("Wikipedia dataset loaded")
 
 
 
 
 
 
324
 
325
- # TF-IDF ๋ฒกํ„ฐ๋ผ์ด์ € ์ดˆ๊ธฐํ™”
326
- questions = wiki_dataset['train']['question'][:10000]
327
- vectorizer = TfidfVectorizer(max_features=1000)
328
- question_vectors = vectorizer.fit_transform(questions)
329
- logger.info("TF-IDF vectorization completed")
330
 
331
- # UI ์‹คํ–‰
332
- demo = create_demo()
333
- demo.launch()
 
 
 
 
334
 
335
  # ํ…Œ์ŠคํŠธ ์ฝ”๋“œ
336
  class TestChatBot(unittest.TestCase):
 
 
 
337
  def test_file_processing(self):
338
- # ํ…Œ์ŠคํŠธ ๊ตฌํ˜„
339
- pass
 
 
340
 
341
  def test_context_search(self):
342
- # ํ…Œ์ŠคํŠธ ๊ตฌํ˜„
343
- pass
 
 
 
33
  )
34
  logger = logging.getLogger(__name__)
35
 
36
+ # ์ „์—ญ ๋ณ€์ˆ˜
37
+ model = None
38
+ tokenizer = None
39
+ current_file_context = None
40
+
41
+ # CSS ์Šคํƒ€์ผ
42
+ CSS = """
43
+ .chat-container {
44
+ height: 600px !important;
45
+ margin-bottom: 10px;
46
+ }
47
+ .input-container {
48
+ height: 70px !important;
49
+ display: flex;
50
+ align-items: center;
51
+ gap: 10px;
52
+ margin-top: 5px;
53
+ }
54
+ .input-textbox {
55
+ height: 70px !important;
56
+ border-radius: 8px !important;
57
+ font-size: 1.1em !important;
58
+ padding: 10px 15px !important;
59
+ }
60
+ .custom-button {
61
+ background: linear-gradient(145deg, #2196f3, #1976d2);
62
+ color: white;
63
+ border-radius: 10px;
64
+ padding: 10px 20px;
65
+ font-weight: 600;
66
+ transition: all 0.3s ease;
67
+ }
68
+ """
69
+
70
  # ์„ค์ • ํด๋ž˜์Šค
71
  class Config:
72
  def __init__(self):
 
89
  status: str
90
  timestamp: datetime
91
 
92
+ def initialize_model_and_tokenizer():
93
+ global model, tokenizer
94
+ try:
95
+ model = load_model()
96
+ tokenizer = AutoTokenizer.from_pretrained(config.MODEL_ID)
97
+ return True
98
+ except Exception as e:
99
+ logger.error(f"Initialization error: {str(e)}")
100
+ return False
101
+
102
  # ํŒŒ์ผ ์ฒ˜๋ฆฌ ํด๋ž˜์Šค
103
  class FileProcessor:
104
+ @staticmethod
105
+ def safe_file_read(file_path):
106
+ encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
107
+ for encoding in encodings:
108
+ try:
109
+ with open(file_path, 'r', encoding=encoding) as f:
110
+ return f.read()
111
+ except UnicodeDecodeError:
112
+ continue
113
+ raise FileProcessingError("Unable to read file with supported encodings")
114
+
115
  @staticmethod
116
  def process_pdf(file_path):
117
  try:
 
133
  @staticmethod
134
  def process_csv(file_path):
135
  try:
136
+ return pd.read_csv(file_path)
 
 
 
 
 
 
137
  except Exception as e:
138
  raise FileProcessingError(f"CSV processing error: {str(e)}")
139
 
140
+ @staticmethod
141
+ def analyze_file_content(content, file_type):
142
+ try:
143
+ if file_type == 'pdf':
144
+ words = len(content.split())
145
+ lines = content.count('\n') + 1
146
+ return f"PDF Analysis:\nWords: {words}\nLines: {lines}"
147
+ elif file_type == 'csv':
148
+ df = pd.DataFrame(content)
149
+ return f"CSV Analysis:\nRows: {len(df)}\nColumns: {len(df.columns)}"
150
+ else:
151
+ lines = content.split('\n')
152
+ return f"Text Analysis:\nLines: {len(lines)}"
153
+ except Exception as e:
154
+ raise FileProcessingError(f"Content analysis error: {str(e)}")
155
+
156
  # ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ
157
  @torch.no_grad()
158
  def clear_cuda_memory():
159
  if torch.cuda.is_available():
160
  torch.cuda.empty_cache()
161
  gc.collect()
162
+ if model is not None:
163
+ model.cpu()
164
 
165
  # ๋ชจ๋ธ ๋กœ๋“œ
166
  @spaces.GPU
 
196
  except Exception as e:
197
  logger.error(f"Context search error: {str(e)}")
198
  return []
 
199
  # ์ŠคํŠธ๋ฆฌ๋ฐ ์ฑ„ํŒ…
200
  @spaces.GPU
201
  def stream_chat(message: str, history: list, uploaded_file, temperature: float,
202
  max_new_tokens: int, top_p: float, top_k: int, penalty: float) -> Iterator[Tuple[str, list]]:
203
  """
204
  ์ŠคํŠธ๋ฆฌ๋ฐ ์ฑ„ํŒ… ์‘๋‹ต์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  """
206
+ global model, tokenizer, current_file_context
207
 
208
  try:
209
+ if model is None or tokenizer is None:
210
+ if not initialize_model_and_tokenizer():
211
+ raise Exception("Model initialization failed")
212
 
213
  logger.info(f'Processing message: {message}')
214
  logger.debug(f'History length: {len(history)}')
 
223
  elif file_ext == '.csv':
224
  content = FileProcessor.process_csv(uploaded_file.name)
225
  else:
226
+ content = FileProcessor.safe_file_read(uploaded_file.name)
227
 
228
+ file_context = FileProcessor.analyze_file_content(content, file_ext.replace('.', ''))
229
  current_file_context = file_context
230
  except Exception as e:
231
  logger.error(f"File processing error: {str(e)}")
 
253
  return_tensors="pt"
254
  ).to("cuda")
255
 
256
+ # ์ž…๋ ฅ ๊ธธ์ด ์ฒดํฌ
257
+ if len(inputs.input_ids[0]) > config.MAX_TOKENS:
258
+ raise ValueError("Input too long")
259
+
260
+ streamer = TextIteratorStreamer(
261
+ tokenizer,
262
+ timeout=30.0,
263
+ skip_prompt=True,
264
+ skip_special_tokens=True
265
+ )
266
 
267
  generate_kwargs = dict(
268
  inputs,
 
278
 
279
  clear_cuda_memory()
280
 
281
+ with torch.no_grad():
282
+ thread = Thread(target=model.generate, kwargs=generate_kwargs)
283
+ thread.start()
284
 
285
+ buffer = ""
286
+ for new_text in streamer:
287
+ buffer += new_text
288
+ yield "", history + [[message, buffer]]
289
 
290
  clear_cuda_memory()
291
 
 
296
 
297
  # UI ์ƒ์„ฑ
298
  def create_demo():
299
+ with gr.Blocks(css=CSS) as demo:
 
300
  with gr.Column(elem_classes="markdown-style"):
301
  gr.Markdown("""
302
  # ๐Ÿค– RAGOndevice
 
307
  chatbot = gr.Chatbot(
308
  value=[],
309
  height=600,
310
+ label="AI Assistant",
311
  elem_classes="chat-container"
312
  )
313
 
 
314
  with gr.Row(elem_classes="input-container"):
315
  with gr.Column(scale=1, min_width=70):
316
  file_upload = gr.File(
 
345
  scale=1
346
  )
347
 
 
348
  with gr.Accordion("๐ŸŽฎ Advanced Settings", open=False):
349
  with gr.Row():
350
  with gr.Column(scale=1):
 
379
 
380
  # ๋ฉ”์ธ ์‹คํ–‰
381
  if __name__ == "__main__":
382
+ try:
383
+ # ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
384
+ if not initialize_model_and_tokenizer():
385
+ logger.error("Failed to initialize model and tokenizer")
386
+ exit(1)
387
+
388
+ # ์œ„ํ‚คํ”ผ๋””์•„ ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ
389
+ wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
390
+ logger.info("Wikipedia dataset loaded")
391
 
392
+ # TF-IDF ๋ฒกํ„ฐ๋ผ์ด์ € ์ดˆ๊ธฐํ™”
393
+ questions = wiki_dataset['train']['question'][:10000]
394
+ vectorizer = TfidfVectorizer(max_features=1000)
395
+ question_vectors = vectorizer.fit_transform(questions)
396
+ logger.info("TF-IDF vectorization completed")
397
 
398
+ # UI ์‹คํ–‰
399
+ demo = create_demo()
400
+ demo.launch(share=False, server_name="0.0.0.0")
401
+
402
+ except Exception as e:
403
+ logger.error(f"Application startup error: {str(e)}")
404
+ exit(1)
405
 
406
  # ํ…Œ์ŠคํŠธ ์ฝ”๋“œ
407
  class TestChatBot(unittest.TestCase):
408
+ def setUp(self):
409
+ self.file_processor = FileProcessor()
410
+
411
  def test_file_processing(self):
412
+ # ํŒŒ์ผ ์ฒ˜๋ฆฌ ํ…Œ์ŠคํŠธ
413
+ test_content = "Test content"
414
+ result = self.file_processor.analyze_file_content(test_content, 'txt')
415
+ self.assertIsNotNone(result)
416
 
417
  def test_context_search(self):
418
+ # ์ปจํ…์ŠคํŠธ ๊ฒ€์ƒ‰ ํ…Œ์ŠคํŠธ
419
+ test_query = "ํ…Œ์ŠคํŠธ ์งˆ๋ฌธ"
420
+ result = find_relevant_context(test_query)
421
+ self.assertIsInstance(result, list)