kelSidenna commited on
Commit
aaf1d72
·
verified ·
1 Parent(s): b157648

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -59
app.py CHANGED
@@ -1,60 +1,3 @@
1
- # app.py
2
 
3
- import streamlit as st
4
- from unsloth import FastLanguageModel
5
- from transformers import TextStreamer
6
-
7
- # To speed up model loading in repeated queries, you can use st.cache_resource (Streamlit 1.18+).
8
- @st.cache_resource
9
- def load_unsloth_model(
10
- model_name="azizsi/model2",
11
- max_seq_length=4096,
12
- dtype="float16",
13
- load_in_4bit=False
14
- ):
15
- """
16
- Loads and prepares the model for inference using FastLanguageModel from Unsloth.
17
- Returns (model, tokenizer).
18
- """
19
- model, tokenizer = FastLanguageModel.from_pretrained(
20
- model_name=model_name,
21
- max_seq_length=max_seq_length,
22
- dtype=dtype,
23
- load_in_4bit=load_in_4bit
24
- )
25
-
26
- # Enable 2x faster inference (per Unsloth docs)
27
- FastLanguageModel.for_inference(model)
28
-
29
- return model, tokenizer
30
-
31
-
32
- def main():
33
- st.title("Unsloth Model Demo")
34
-
35
- # Provide a text input area for the user
36
- user_input = st.text_area("Enter your prompt:", "")
37
-
38
- # Generate button
39
- if st.button("Generate"):
40
- with st.spinner("Generating response..."):
41
- # Load the model & tokenizer
42
- model, tokenizer = load_unsloth_model()
43
-
44
- # Create a TextStreamer to stream tokens or capture final text
45
- streamer = TextStreamer(tokenizer)
46
-
47
- # Tokenize user prompt and move to GPU (or the model's device)
48
- inputs = tokenizer(user_input, return_tensors="pt").to(model.device)
49
-
50
- # Generate up to 128 new tokens (modify as desired)
51
- outputs = model.generate(**inputs, streamer=streamer, max_new_tokens=128)
52
-
53
- # If you want to display the entire response at once:
54
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
55
-
56
- st.markdown("**Response:**")
57
- st.write(generated_text)
58
-
59
- if __name__ == "__main__":
60
- main()
 
1
+ import gradio as gr
2
 
3
+ gr.load("azizsi/model2").launch()