Moha782 committed
Commit bca5017 · verified · 1 Parent(s): 39f6f38

Update app.py

Files changed (1):
  app.py (+20 -26)
app.py CHANGED
```diff
@@ -1,11 +1,17 @@
 import gradio as gr
+from huggingface_hub import InferenceClient
 from pathlib import Path
-from transformers import RagTokenForGeneration, AutoTokenizer, AutoModelForCausalLM
+from transformers import RagTokenForGeneration, RagTokenizer, RagRetriever
 from pdfplumber import open as open_pdf
 from typing import List
 
+"""
+For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
+"""
+client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+
 # Load the PDF file
-pdf_path = Path("apexcustoms.pdf")
+pdf_path = Path("path/to/your/pdf/file.pdf")
 with open_pdf(pdf_path) as pdf:
     text = "\n".join(page.extract_text() for page in pdf.pages)
 
```
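Two things worth noting in this hunk. The new `pdf_path` is a placeholder that replaces the real `apexcustoms.pdf`, so the Space will fail at startup until it is pointed back at an actual file. Separately, pdfplumber's `page.extract_text()` returns `None` for pages with no extractable text, so the bare `"\n".join(...)` kept as context here can raise a `TypeError` on scanned or image-only pages. A minimal hardening sketch (the `or ""` guard is my addition, not part of the commit):

```python
from pathlib import Path
from pdfplumber import open as open_pdf

pdf_path = Path("apexcustoms.pdf")  # the real file; the diff's new path is only a placeholder

with open_pdf(pdf_path) as pdf:
    # extract_text() can return None on image-only pages, which would break str.join
    text = "\n".join(page.extract_text() or "" for page in pdf.pages)
```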
```diff
@@ -15,19 +21,14 @@ text_chunks: List[str] = [text[i:i+chunk_size] for i in range(0, len(text), chun
 
 # Load the RAG model and tokenizer for retrieval
 rag_tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-nq")
-rag_model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
-
-# Load the DialoGPT model and tokenizer for generation
-dialogpt_tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
-dialogpt_model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
+retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
+rag_model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
 
 def respond(
     message,
     history: list[tuple[str, str]],
     system_message,
     max_tokens,
-    num_beams,
-    no_repeat_ngram_size,
 ):
     messages = [{"role": "system", "content": system_message}]
 
```
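One flag on this hunk: the import change above drops `AutoTokenizer` from the `transformers` import, but the unchanged context line `rag_tokenizer = AutoTokenizer.from_pretrained(...)` still references it, which will raise a `NameError` at import time. A sketch of the stock RAG setup using only the names the new revision actually imports (retriever options copied from the diff; substituting `RagTokenizer` for `AutoTokenizer` is my reading of the intent):

```python
from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

# All three pieces come from the same facebook/rag-token-nq checkpoint
rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)
rag_model = RagTokenForGeneration.from_pretrained(
    "facebook/rag-token-nq", retriever=retriever
)
```

With the retriever attached this way, retrieval happens inside `rag_model.generate(...)`; the later context line `rag_model(rag_input_ids, text_chunks, return_retrieved_inputs=True)` does not match any stock `RagTokenForGeneration` signature I am aware of.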
```diff
@@ -46,21 +47,16 @@ def respond(
     rag_output = rag_model(rag_input_ids, text_chunks, return_retrieved_inputs=True)
     retrieved_text = rag_output.retrieved_inputs
 
-    # Encode the context and user's message for DialoGPT
-    input_ids = dialogpt_tokenizer.encode(retrieved_text + "\n\n" + message, return_tensors="pt")
-
-    # Generate the response using the DialoGPT model
-    output = dialogpt_model.generate(
-        input_ids,
-        max_length=max_tokens,
-        num_beams=num_beams,
-        no_repeat_ngram_size=no_repeat_ngram_size,
-        early_stopping=True
-    )
-
-    response = dialogpt_tokenizer.decode(output[0], skip_special_tokens=True)
-
-    yield response
+    # Generate the response using the zephyr model
+    for message in client.chat_completion(
+        messages,
+        max_tokens=max_tokens,
+        stream=True,
+        files={"context": retrieved_text},  # Pass retrieved text as context
+    ):
+        token = message.choices[0].delta.content
+        response += token
+        yield response
 
 """
 For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
```
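The added loop has two problems as committed: `response += token` runs before `response` is ever bound, and `InferenceClient.chat_completion` takes no `files` parameter in any `huggingface_hub` release I know of, so this call should raise a `TypeError`. A hedged rework under those assumptions, folding the retrieved context into the message list instead and no longer shadowing the `message` argument:

```python
    # Attach the retrieved context as part of the conversation rather than the
    # nonexistent files= kwarg; this framing is illustrative, not from the commit.
    messages.append(
        {"role": "user", "content": f"Context:\n{retrieved_text}\n\n{message}"}
    )

    response = ""  # must be initialized before accumulating streamed tokens
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
    ):
        token = chunk.choices[0].delta.content
        if token:  # delta.content can be empty on the final chunk
            response += token
            yield response
```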
```diff
@@ -70,8 +66,6 @@ demo = gr.ChatInterface(
     additional_inputs=[
         gr.Textbox(value="You are a helpful car configuration assistant, specifically you are the assistant for Apex Customs (https://www.apexcustoms.com/). Given the user's input, provide suggestions for car models, colors, and customization options. Be conversational in your responses. You should remember the user car model and tailor your answers accordingly. You limit yourself to answering the given question and maybe propose a suggestion but not write the next question of the user. \n\nUser: ", label="System message"),
         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=1, maximum=10, value=4, step=1, label="Number of beams"),
-        gr.Slider(minimum=1, maximum=5, value=2, step=1, label="No repeat ngram size"),
     ],
 )
 
```
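Removing these two sliders correctly mirrors dropping `num_beams` and `no_repeat_ngram_size` from the `respond` signature. For reference, a minimal sketch of how the interface wires together after this commit (system prompt abbreviated; `demo.launch()` is the usual Gradio entry point, assumed rather than shown in this diff):

```python
import gradio as gr

demo = gr.ChatInterface(
    respond,  # the generator defined above; yielding partials gives streaming output
    additional_inputs=[
        gr.Textbox(value="You are a helpful car configuration assistant...", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    ],
)

if __name__ == "__main__":
    demo.launch()
```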
 
71