mery22 committed on
Commit
b4e5268
·
verified ·
1 Parent(s): b48a9c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -0
app.py CHANGED
@@ -1,4 +1,119 @@
 
 
 
 
 
 
 
 
1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import gradio as gr
4
  def qna_chatbot(message, history):
 
1
+ import os
2
+ import torch
3
+ from transformers import (
4
+ AutoTokenizer,
5
+ AutoModelForCausalLM,
6
+ BitsAndBytesConfig,
7
+ pipeline
8
+ )
9
 
10
+ from transformers import BitsAndBytesConfig
11
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
12
+ from langchain.vectorstores import FAISS
13
+
14
+ from langchain.prompts import PromptTemplate
15
+ from langchain.schema.runnable import RunnablePassthrough
16
+ from langchain.llms import HuggingFacePipeline
17
+ from langchain.chains import LLMChain
18
+ import transformers
19
+ model_name='mistralai/Mistral-7B-Instruct-v0.1'
20
+
21
+ model_config = transformers.AutoConfig.from_pretrained(
22
+ model_name,
23
+ )
24
+
25
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
26
+ tokenizer.pad_token = tokenizer.eos_token
27
+ tokenizer.padding_side = "right"
28
+ #################################################################
29
+ # bitsandbytes parameters
30
+ #################################################################
31
+
32
+ # Activate 4-bit precision base model loading
33
+ use_4bit = True
34
+
35
+ # Compute dtype for 4-bit base models
36
+ bnb_4bit_compute_dtype = "float16"
37
+
38
+ # Quantization type (fp4 or nf4)
39
+ bnb_4bit_quant_type = "nf4"
40
+
41
+ # Activate nested quantization for 4-bit base models (double quantization)
42
+ use_nested_quant = False
43
+ #################################################################
44
+ # Set up quantization config
45
+ #################################################################
46
+ compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
47
+
48
+ bnb_config = BitsAndBytesConfig(
49
+ load_in_4bit=use_4bit,
50
+ bnb_4bit_quant_type=bnb_4bit_quant_type,
51
+ bnb_4bit_compute_dtype=compute_dtype,
52
+ bnb_4bit_use_double_quant=use_nested_quant,
53
+ )
54
+ #############################################################
55
+ # Load the pre-trained model using the 4-bit quantization config
56
+ #################################################################
57
+ model = AutoModelForCausalLM.from_pretrained(
58
+ model_name,
59
+ quantization_config=bnb_config,
60
+ )
61
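+ # Connect query to FAISS index using a retriever (NOTE(review): `db` is not defined in the lines shown above; a FAISS vector store must be built and assigned to `db` first)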
62
+ retriever = db.as_retriever(
63
+ search_type="mmr",
64
+ search_kwargs={'k': 1}
65
+ )
66
+ from langchain.llms import HuggingFacePipeline
67
+ from langchain.prompts import PromptTemplate
68
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
69
+
70
+ text_generation_pipeline = transformers.pipeline(
71
+ model=model,
72
+ tokenizer=tokenizer,
73
+ task="text-generation",
74
+
75
+ temperature=0.02,
76
+ repetition_penalty=1.1,
77
+ return_full_text=True,
78
+ max_new_tokens=512,
79
+ )
80
+
81
+ prompt_template = """
82
+ ### [INST]
83
+ Instruction: You are a Q&A assistant. Your goal is to answer questions as accurately as possible based on the instructions and context provided without using prior knowledge.You answer in FRENCH
84
+ Analyse carefully the context and provide a direct answer based on the context.
85
+ Answer in french only
86
+ {context}
87
+ Vous devez répondre aux questions en français.
88
+
89
+ ### QUESTION:
90
+ {question}
91
+ [/INST]
92
+ Answer in french only
93
+ Vous devez répondre aux questions en français.
94
+
95
+ """
96
+
97
+ mistral_llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
98
+
99
+ # Create prompt from prompt template
100
+ prompt = PromptTemplate(
101
+ input_variables=["question"],
102
+ template=prompt_template,
103
+ )
104
+
105
+ # Create llm chain
106
+ llm_chain = LLMChain(llm=mistral_llm, prompt=prompt)
107
+ from langchain.chains import RetrievalQA
108
+
109
+
110
+ retriever.search_kwargs = {'k':1}
111
+ qa = RetrievalQA.from_chain_type(
112
+ llm=mistral_llm,
113
+ chain_type="stuff",
114
+ retriever=retriever,
115
+ chain_type_kwargs={"prompt": prompt},
116
+ )
117
 
118
  import gradio as gr
119
  def qna_chatbot(message, history):