Update app.py
Browse files
app.py
CHANGED
@@ -1,11 +1,12 @@
|
|
|
|
1 |
import os
|
2 |
import gradio as gr
|
3 |
-
from langchain_community.llms import
|
4 |
from langchain.prompts import PromptTemplate
|
5 |
|
6 |
# Initialize the chatbot
|
7 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
8 |
-
llm =
|
9 |
repo_id="google/gemma-1.1-7b-it",
|
10 |
task="text-generation",
|
11 |
model_kwargs={
|
@@ -25,7 +26,7 @@ QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"],templat
|
|
25 |
|
26 |
def predict(message, history):
|
27 |
input_prompt = QA_CHAIN_PROMPT.format(question=message, context=history)
|
28 |
-
ai_msg = llm.generate(input_prompt)
|
29 |
return ai_msg
|
30 |
|
31 |
gr.ChatInterface(predict).launch()
|
|
|
1 |
+
|
2 |
import os
|
3 |
import gradio as gr
|
4 |
+
from langchain_community.llms import HuggingFaceEndpoint
|
5 |
from langchain.prompts import PromptTemplate
|
6 |
|
7 |
# Initialize the chatbot
|
8 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
9 |
+
llm = HuggingFaceEndpoint(
|
10 |
repo_id="google/gemma-1.1-7b-it",
|
11 |
task="text-generation",
|
12 |
model_kwargs={
|
|
|
26 |
|
27 |
def predict(message, history):
|
28 |
input_prompt = QA_CHAIN_PROMPT.format(question=message, context=history)
|
29 |
+
ai_msg = llm.generate([input_prompt]) # Pass a list of strings
|
30 |
return ai_msg
|
31 |
|
32 |
gr.ChatInterface(predict).launch()
|