File size: 1,867 Bytes
d589ce0
 
3dcb9f7
d589ce0
 
 
 
3dcb9f7
d589ce0
3dcb9f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d589ce0
3dcb9f7
d589ce0
d902198
 
d589ce0
d902198
 
0254726
d902198
d589ce0
8ed0a63
d589ce0
3dcb9f7
 
 
 
 
d589ce0
3dcb9f7
d589ce0
3dcb9f7
 
 
 
 
d589ce0
 
3dcb9f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Custom stylesheet passed to the Gradio interface through its `css` argument.
# - html/body: no margin or padding, full height, scrolling disabled.
# - body::before: a fixed, viewport-sized pseudo-element that paints a remote
#   background image at 60% opacity behind the page content (z-index -1).
# - .gradio-container: a column flexbox that centres its children both ways
#   and spans the full viewport height.
css = """

html, body {
    margin: 0;
    padding: 0;
    height: 100%;
    overflow: hidden;
}
body::before {
    content: '';
    position: fixed;
    top: 0;
    left: 0;
    width: 100vw;
    height: 100vh;
    background-image: url('https://png.pngtree.com/background/20230413/original/pngtree-medical-color-cartoon-blank-background-picture-image_2422159.jpg');
    background-size: cover;
    background-repeat: no-repeat;
    opacity: 0.60;    
    background-position: center;
    z-index: -1;    
}
.gradio-container {
    display: flex;
    flex-direction: column;
    justify-content: center;
    align-items: center;
    height: 100vh;  
}

"""

# Hugging Face Hub repository id of the checkpoint served by this app.
model_id = "harishnair04/Gemma-medtr-2b-sft-v2"
# Disabled alternative: load a GGUF export of the same repository by passing
# `gguf_file=filename` to both `from_pretrained` calls.
# filename = "Gemma-medtr-2b-sft-v2.gguf"

# tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=filename)
# gemma_model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=filename)
# Load the tokenizer and the causal language model once, at module import, so
# every request handled later reuses the same objects.
tokenizer = AutoTokenizer.from_pretrained(model_id)
gemma_model = AutoModelForCausalLM.from_pretrained(model_id)

# Reuse the end-of-sequence token id as the padding token id.
# NOTE(review): presumably done so generation has a pad id available — confirm
# whether the checkpoint's tokenizer already defines its own pad token.
tokenizer.pad_token_id = tokenizer.eos_token_id

def respond(input1):
    """Generate a sampled completion for ``input1`` with the loaded model.

    The text is tokenized verbatim (no prompt template is applied), passed to
    ``gemma_model.generate`` and the first returned sequence is decoded with
    special tokens stripped.

    Args:
        input1: Prompt text entered by the user.

    Returns:
        The decoded text of the first sequence returned by ``generate``, or
        an empty string when ``input1`` is ``None``, empty or whitespace-only.
    """
    # Nothing to condition the model on: return early instead of running a
    # generation pass on a prompt that carries no content.
    if input1 is None or (isinstance(input1, str) and not input1.strip()):
        return ""

    inputs = tokenizer(input1, return_tensors="pt")
    out = gemma_model.generate(
        **inputs,
        do_sample=True,  # sampling must be on for `temperature` to apply
        temperature=0.4,
        max_new_tokens=200,  # upper bound on newly generated tokens
    )
    # `out` holds one sequence per input row; only one prompt was submitted.
    return tokenizer.decode(out[0], skip_special_tokens=True)

# Single text-in / text-out Gradio form: the submitted text is handed to
# `respond` and its return value is shown as the output, styled with the
# custom stylesheet defined in `css`.
chat_interface = gr.Interface(
    title="MT CHAT",
    description="Gemma 2b finetuned on medical transcripts",
    fn=respond,
    inputs="text",
    outputs="text",
    css=css,
)

# Start serving the interface as soon as the module is executed.
chat_interface.launch()