Roberta2024 commited on
Commit
02cfa94
·
verified ·
1 Parent(s): 08fe8ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -30
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from PyPDF2 import PdfReader
5
  import google.generativeai as genai
6
  import os
@@ -8,8 +8,8 @@ from langsmith import Client
8
  from ragas.metrics import faithfulness, answer_relevancy, context_relevancy
9
 
10
  # 加載模型
11
- openelm_model = AutoModelForCausalLM.from_pretrained("apple/OpenELM-270M", trust_remote_code=True)
12
- openelm_tokenizer = AutoTokenizer.from_pretrained("apple/OpenELM-270M", trust_remote_code=True) # Add trust_remote_code here
13
 
14
  # Gemini API 設置
15
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
@@ -35,11 +35,11 @@ def gemini_generate(prompt, max_tokens):
35
  return response.text
36
 
37
  def openelm_generate(prompt, max_tokens):
38
- tokenized_prompt = openelm_tokenizer(prompt, return_tensors="pt")
39
  output_ids = openelm_model.generate(
40
- tokenized_prompt["input_ids"],
41
  max_length=max_tokens,
42
- pad_token_id=0,
43
  )
44
  return openelm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
45
 
@@ -56,42 +56,21 @@ def process_query(pdf_file, llm_choice, query, max_tokens, api_key):
56
  GOOGLE_API_KEY = api_key
57
  genai.configure(api_key=GOOGLE_API_KEY)
58
 
59
- # 從 PDF 提取文本
60
  pdf_path = pdf_file.name
61
  context = extract_text_from_pdf(pdf_path)
62
 
63
- # 根據選擇的 LLM 生成回應
64
  if llm_choice == "Gemini":
65
  response = gemini_generate(f"上下文: {context}\n問題: {query}", max_tokens)
66
  else: # OpenELM
67
  response = openelm_generate(f"上下文: {context}\n問題: {query}", max_tokens)
68
 
69
- # 評估回應
70
  faith_score, ans_rel_score, ctx_rel_score = evaluate_response(response, context, query)
71
 
72
  return response, faith_score, ans_rel_score, ctx_rel_score
73
  except Exception as e:
74
- return str(e), 0, 0, 0 # 返回錯誤消息和零分數
75
 
76
- # Gradio 介面
77
- iface = gr.Interface(
78
- fn=process_query,
79
- inputs=[
80
- gr.File(label="上傳 PDF"),
81
- gr.Dropdown(["Gemini", "OpenELM"], label="選擇 LLM"),
82
- gr.Textbox(label="輸入您的問題"),
83
- gr.Slider(minimum=50, maximum=1000, step=50, label="最大令牌數"),
84
- gr.Textbox(label="Gemini API Key (可選)", type="password")
85
- ],
86
- outputs=[
87
- gr.Textbox(label="生成的答案"),
88
- gr.Number(label="真實性得分"),
89
- gr.Number(label="答案相關性得分"),
90
- gr.Number(label="上下文相關性得分")
91
- ],
92
- title="多模型 LLM 查詢介面,支持 PDF 上下文",
93
- description="上傳 PDF,選擇 LLM,並提出問題。回應將使用 RAGAS 指標進行評估。"
94
- )
95
 
96
  if __name__ == "__main__":
97
- iface.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer
4
  from PyPDF2 import PdfReader
5
  import google.generativeai as genai
6
  import os
 
8
  from ragas.metrics import faithfulness, answer_relevancy, context_relevancy
9
 
10
  # 加載模型
11
+ openelm_model = AutoModelForCausalLM.from_pretrained("apple/OpenELM-270M", revision="main", trust_remote_code=True)
12
+ openelm_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") # OpenELM 使用 GPT2 tokenizer
13
 
14
  # Gemini API 設置
15
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
 
35
  return response.text
36
 
37
  def openelm_generate(prompt, max_tokens):
38
+ input_ids = openelm_tokenizer.encode(prompt, return_tensors="pt")
39
  output_ids = openelm_model.generate(
40
+ input_ids,
41
  max_length=max_tokens,
42
+ pad_token_id=openelm_tokenizer.eos_token_id
43
  )
44
  return openelm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
45
 
 
56
  GOOGLE_API_KEY = api_key
57
  genai.configure(api_key=GOOGLE_API_KEY)
58
 
 
59
  pdf_path = pdf_file.name
60
  context = extract_text_from_pdf(pdf_path)
61
 
 
62
  if llm_choice == "Gemini":
63
  response = gemini_generate(f"上下文: {context}\n問題: {query}", max_tokens)
64
  else: # OpenELM
65
  response = openelm_generate(f"上下文: {context}\n問題: {query}", max_tokens)
66
 
 
67
  faith_score, ans_rel_score, ctx_rel_score = evaluate_response(response, context, query)
68
 
69
  return response, faith_score, ans_rel_score, ctx_rel_score
70
  except Exception as e:
71
+ return str(e), 0, 0, 0
72
 
73
+ # Gradio 界面設置保持不變...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  if __name__ == "__main__":
76
+ iface.launch()