CannaTech commited on
Commit
70760e5
·
verified ·
1 Parent(s): b9b3da5

Suck it bitches

Browse files
Files changed (1) hide show
  1. app.py +29 -18
app.py CHANGED
@@ -1,31 +1,42 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
3
 
4
- # Load the pre-trained model and tokenizer
5
- model_name = "distilbert-base-uncased-distilled-squad"
6
- model = AutoModelForQuestionAnswering.from_pretrained(model_name)
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
 
9
- # Load custom knowledge document
10
- with open("knowledge.txt", "r") as file:
11
- custom_knowledge = file.read()
12
 
13
- # Initialize the question-answering pipeline
14
  qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
15
 
 
 
 
 
 
16
  def answer_question(question):
17
- result = qa_pipeline(question=question, context=custom_knowledge)
18
- return result['answer']
 
 
 
 
 
19
 
20
- # Set up the Gradio interface
21
- interface = gr.Interface(
22
  fn=answer_question,
23
- inputs="text",
24
  outputs="text",
25
- title="Custom Knowledge QA",
26
- description="Ask questions based on the custom knowledge document."
 
 
 
27
  )
28
 
 
29
  if __name__ == "__main__":
30
- interface.launch()
31
-
 
1
  import gradio as gr
2
+ from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
3
 
4
+ # 1. Choose a bilingual or multilingual QA model
5
+ MODEL_NAME = "mrm8488/xlm-roberta-large-finetuned-squadv2"
 
 
6
 
7
+ # 2. Load model + tokenizer
8
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
+ model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
10
 
11
+ # 3. Initialize QA pipeline
12
  qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
13
 
14
+ # 4. Load or define custom knowledge base
15
+ with open("knowledge.txt", "r", encoding="utf-8") as f:
16
+ knowledge_text = f.read()
17
+
18
+ # 5. Define function to answer questions
19
  def answer_question(question):
20
+ if not question.strip():
21
+ return "Please ask a valid question."
22
+ try:
23
+ result = qa_pipeline(question=question, context=knowledge_text)
24
+ return result["answer"]
25
+ except Exception as e:
26
+ return f"Error: {str(e)}"
27
 
28
+ # 6. Build Gradio interface
29
+ iface = gr.Interface(
30
  fn=answer_question,
31
+ inputs=gr.Textbox(lines=2, placeholder="Enter your question here..."),
32
  outputs="text",
33
+ title="Budtender LLM (Bilingual QA)",
34
+ description=(
35
+ "A bilingual Q&A model trained on Spanish and English data. "
36
+ "Ask your cannabis-related questions here!"
37
+ )
38
  )
39
 
40
+ # 7. Launch app
41
  if __name__ == "__main__":
42
+ iface.launch()