File size: 1,458 Bytes
9ec795d
957486f
 
 
 
9ec795d
aabe5cb
9ec795d
957486f
 
 
 
aabe5cb
957486f
 
9ec795d
aabe5cb
9ec795d
957486f
 
aabe5cb
957486f
9ec795d
 
 
aabe5cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1a4fee
aabe5cb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import gradio as gr
from transformers import pipeline
import torch
import random
import gradio as gr

# Hugging Face Hub id of the fine-tuned Mistral-NeMo dad-joke model.
model_path = "shuttie/mistral-nemo-dadjokes-v1"

# Text-generation pipeline, built once at import time so every chat request
# reuses the same loaded model. Weights are loaded in bfloat16, and
# device_map="auto" leaves device placement to transformers/accelerate
# (an explicit device= argument is not passed alongside it).
generator = pipeline(
    "text-generation",
    model=model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Instruction template wrapped around each user message before generation;
# presumably the instruct format the model was fine-tuned on — confirm
# against the model card.
prompt = "[INST] {input} [/INST]"

def make_response(message, history):
    """Generate a single reply for one Gradio chat turn.

    Args:
        message: The user's latest chat message. It is substituted into the
            module-level ``prompt`` template.
        history: Previous chat turns supplied by ``gr.ChatInterface``. This
            argument is unused: the reply is generated from the current
            message alone, so earlier turns do not influence it.

    Returns:
        The text produced by ``generator`` for the formatted prompt, capped at
        64 new tokens. ``return_full_text=False`` excludes the prompt itself.
    """
    # Named ``formatted_prompt`` rather than ``input`` so the builtin
    # ``input()`` is not shadowed inside this function.
    formatted_prompt = prompt.format(input=message)
    outputs = generator(
        formatted_prompt,
        return_full_text=False,  # only the continuation, not the echoed prompt
        max_new_tokens=64,
        num_return_sequences=1,
    )
    # Exactly one sequence was requested, so the first entry is the reply.
    return outputs[0]["generated_text"]


if __name__ == "__main__":
    # Sample prompts shown as clickable examples beneath the chat box.
    sample_prompts = [
        "My Wife gets mad at my pickle puns",
        "I saw a couple cows smoking as they played poker",
        "A vegan enters the bar and says",
        "Accidentally took my cat’s medication.",
        "My grandpa has the heart of a lion",
        "Why can’t dinosaurs laugh?",
        "Why don’t Americans use the metric system?",
        "How do you catch a squirrel?",
        "My girlfriend invited me over to finally meet her cannibal father.",
        "Last time I stayed in a hotel I asked for the porn channel to be disabled",
    ]
    message_box = gr.Textbox(
        placeholder="Ask me a question",
        container=False,
        scale=7,
    )
    demo = gr.ChatInterface(
        make_response,
        textbox=message_box,
        examples=sample_prompts,
        # Example responses are not generated ahead of time at startup.
        cache_examples=False,
    )
    # Bind on all network interfaces and also request a public share link.
    demo.launch(server_name="0.0.0.0", share=True)