Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,9 @@
|
|
|
|
1 |
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
import random
|
5 |
from textwrap import wrap
|
6 |
-
import spaces
|
7 |
|
8 |
def wrap_text(text, width=90):
|
9 |
lines = text.split('\n')
|
@@ -51,7 +51,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id = model_id, trust_remote_code
|
|
51 |
# Specify the configuration class for the model
|
52 |
#model_config = AutoConfig.from_pretrained(base_model_id)
|
53 |
|
54 |
-
model =
|
55 |
|
56 |
class ChatBot:
|
57 |
def __init__(self):
|
@@ -64,7 +64,7 @@ class ChatBot:
|
|
64 |
|
65 |
def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
|
66 |
# Combine the user's input with the system prompt
|
67 |
-
formatted_input = f"<s>[INST]{
|
68 |
|
69 |
# Encode the formatted input using the tokenizer
|
70 |
user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt")
|
|
|
1 |
+
import spaces
|
2 |
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM
|
3 |
import torch
|
4 |
import gradio as gr
|
5 |
import random
|
6 |
from textwrap import wrap
|
|
|
7 |
|
8 |
def wrap_text(text, width=90):
|
9 |
lines = text.split('\n')
|
|
|
51 |
# Specify the configuration class for the model
|
52 |
#model_config = AutoConfig.from_pretrained(base_model_id)
|
53 |
|
54 |
+
model = AutoModelForCausalLM.from_pretrained(model_id , torch_dtype=torch.float16 , device_map= "auto" )
|
55 |
|
56 |
class ChatBot:
|
57 |
def __init__(self):
|
|
|
64 |
|
65 |
def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
|
66 |
# Combine the user's input with the system prompt
|
67 |
+
formatted_input = f"<s> [INST] {example_instruction} [/INST] {example_answer}</s> [INST] {system_prompt} [/INST]"
|
68 |
|
69 |
# Encode the formatted input using the tokenizer
|
70 |
user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt")
|