ArunAIML commited on
Commit
f0e8cfd
·
1 Parent(s): 0ad6d0c

nvdia llm mistral

Browse files
Files changed (2) hide show
  1. app.py +17 -14
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,21 +1,24 @@
1
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 
 
 
2
  import gradio as gr
 
3
 
4
- model_id = "gpt2"
5
 
6
- tokenizer = AutoTokenizer.from_pretrained(model_id)
7
- model = AutoModelForCausalLM.from_pretrained(model_id)
8
- model.to_bettertransformer()
9
 
10
- pipe = pipeline("text-generation", model=model, tokenizer= tokenizer)
 
11
 
12
- def gpt(prompt, top_k, penalty_alpha):
13
- return pipe(prompt, top_k=top_k, penalty_alpha=penalty_alpha)[0]["generated_text"]
14
 
15
- gr.Interface(
16
- gpt,
17
- ["text",gr.Slider(minimum=0, maximum=50, step=1,label="Top_k"),gr.Slider(minimum=0.1, maximum=1.0,label="penalty_alpha")],
18
- "text",
19
- title= "Arun's GPT chatbot",
20
- description = "This is Arun's experimental GPT interface exposing gpt2, feel free to experiment"
21
- ).launch()
 
 
 
1
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
2
+ from langchain_nvidia_ai_endpoints import ChatNVIDIA
3
+ from langchain_core.output_parsers import StrOutputParser
4
+ from langchain_core.prompts import ChatPromptTemplate
5
  import gradio as gr
6
+ import os
7
 
8
+ os.environ["NVIDIA_API_KEY"] = "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"
9
 
10
+ prompt = ChatPromptTemplate.from_messages([("system", "You are a helpful AI assistant named Fred."), ("user", "{input}")])
 
 
11
 
12
+ llm = ChatNVIDIA(model="mixtral_8x7b")
13
+ chain = prompt | llm | StrOutputParser()
14
 
 
 
15
 
16
+ def chat(prompt, history):
17
+
18
+ for chunk in chain.stream({"input": prompt}):
19
+ yield chunk.content
20
+
21
+ gr.Chat
22
+ demo = gr.ChatInterface(chat).queue()
23
+
24
+ demo.launch()
requirements.txt CHANGED
@@ -2,3 +2,5 @@ transformers
2
  gradio
3
  torch
4
  optimum
 
 
 
2
  gradio
3
  torch
4
  optimum
5
+ langchain
6
+ langchain-nvidia-ai-endpoints