import streamlit as st
# from unsloth import FastLanguageModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
import os
import sys

# # Suppress unwanted outputs (e.g., from unsloth or other libraries)
# def suppress_output():
#     sys.stdout = open(os.devnull, 'w')  # Redirect stdout to devnull
#     sys.stderr = open(os.devnull, 'w')  # Redirect stderr to devnull
#
# def restore_output():
#     sys.stdout = sys.__stdout__  # Restore stdout
#     sys.stderr = sys.__stderr__  # Restore stderr

# Load the model (GGUF format)
@st.cache_resource
def load_model():
    # Repository and filenames for the base model and the LoRA adapter
    base_model_repo = "helamouri/Meta-Llama-3.1-8B-Q8_0.gguf"
    base_model_filename = "Meta-Llama-3.1-8B-Q8_0.gguf"
    # adapter_repo = "helamouri/medichat_assignment"  # earlier adapter repo
    adapter_repo = "helamouri/model_medichat_finetuned_v1"
    adapter_filename = "llama3_medichat.gguf"

    # Download the base model and the adapter to local paths
    base_model_path = hf_hub_download(repo_id=base_model_repo, filename=base_model_filename)
    adapter_model_path = hf_hub_download(repo_id=adapter_repo, filename=adapter_filename)

    # Log paths for debugging
    print(f"Base model path: {base_model_path}")
    print(f"Adapter model path: {adapter_model_path}")

    # Load the base GGUF model (the adapter is downloaded but not applied here)
    try:
        model = Llama(model_path=base_model_path)  # , adapter_path=adapter_model_path
        print("Model loaded successfully.")
    except ValueError as e:
        print(f"Error loading model: {e}")
        raise
    return model
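
# Note (hedged): the `adapter_path` keyword commented out above is not a
# llama-cpp-python parameter. If the downloaded adapter file is a
# llama.cpp-compatible LoRA, it could be applied at load time roughly like
# this sketch (not the author's tested path; `n_ctx` is an assumed value):
#
#     model = Llama(
#         model_path=base_model_path,
#         lora_path=adapter_model_path,  # apply the fine-tuned LoRA on top of the base weights
#         n_ctx=2048,                    # assumed context window
#     )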

# Generate a response using Llama.cpp
def generate_response(model, prompt):
    print('prompt')
    print(prompt)
    response = model(
        prompt,
        max_tokens=200,   # Maximum tokens for the response
        temperature=0.7,  # Adjust for creativity (lower = more deterministic)
        top_p=0.9,        # Nucleus sampling
        stop=["\n"]       # Stop generating when a newline is encountered
    )
    print('response["choices"]')
    print(response["choices"])
    return response["choices"][0]["text"]
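
# A hedged sketch of a streaming variant, assuming llama-cpp-python's
# `stream=True` mode (each yielded chunk carries a partial completion under
# choices[0]["text"]); useful if the answer should appear incrementally:
#
#     def generate_response_stream(model, prompt):
#         for chunk in model(prompt, max_tokens=200, temperature=0.7,
#                            top_p=0.9, stream=True):
#             yield chunk["choices"][0]["text"]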

# Load the model and tokenizer (GGUF format)
# @st.cache_resource
# def load_model():
#     model_name = "helamouri/model_medichat_finetuned_v1"  # Replace with your model's GGUF path
#     model = FastLanguageModel.from_pretrained(model_name, device='cpu')  # Load the model using unsloth
#     tokenizer = model.tokenizer  # Assuming the tokenizer is part of the GGUF model object
#     return tokenizer, model

# @st.cache_resource
# def load_model():
#     model_name = "helamouri/model_medichat_finetuned_v1"  # Replace with your model's path
#     # Load the tokenizer
#     tokenizer = AutoTokenizer.from_pretrained(model_name)
#     # Load the model (if it's a causal language model or suitable model type)
#     model = AutoModelForCausalLM.from_pretrained(model_name,
#                                                  device_map="cpu",
#                                                  revision="main",
#                                                  quantize=False,
#                                                  load_in_8bit=False,
#                                                  load_in_4bit=False,
#                                                  # torch_dtype=torch.float32
#                                                  )
#     return tokenizer, model

# Suppress unwanted outputs from unsloth or any other libraries during model loading
# suppress_output()

# Load the GGUF model
print('Loading the model')
model = load_model()

# Restore stdout and stderr
# restore_output()
# App layout
print('Setting App layout')
st.title("MediChat: Your AI Medical Consultation Assistant")
st.markdown("Ask me anything about your health!")
st.write("Enter your symptoms or medical questions below:")
# User input
print('Setting user interface')
user_input = st.text_input("Your Question:")
if st.button("Get Response"):
    if user_input:
        with st.spinner("Generating response..."):
            # Generate the response
            response = generate_response(model, user_input)
            print('Response')
            print(response)
        # Display the response
        st.text_area("Response:", value=response, height=200)
    else:
        st.warning("Please enter a question.")
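
# If the streaming variant sketched above were used, newer Streamlit releases
# expose st.write_stream for incremental display (a sketch under that
# assumption, not part of the original app):
#
#     if st.button("Get Response (streaming)") and user_input:
#         st.write_stream(generate_response_stream(model, user_input))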