Canstralian committed
Commit 31fed9d · verified · 1 Parent(s): ae0d601

Update app.py

Files changed (1):
  1. app.py +86 -146
app.py CHANGED
@@ -1,148 +1,88 @@
-"""Inspired by the SantaCoder demo Hugging Face Space.
-Link: https://huggingface.co/spaces/bigcode/santacoder-demo/tree/main/app.py
-"""
-
-import os
-import gradio as gr
-import torch
-
-from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
-
-REPO = "replit/replit-code-v1-3b"
-
-description = """# <h1 style="text-align: center; color: white;"><span style='color: #F26207;'> Code Completion with replit-code-v1-3b </span></h1>
-<span style="color: white; text-align: center;"> The replit-code-v1-3b model is a 2.7B-parameter LLM trained on 20 languages from the Stack Dedup v1.2 dataset. You can click the button several times to keep completing your code.</span>"""
-
-
-token = os.environ["HUB_TOKEN"]
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-PAD_TOKEN = "<|pad|>"
-EOS_TOKEN = "<|endoftext|>"
-UNK_TOKEN = "<|unk|>"
-MAX_INPUT_TOKENS = 1024  # max tokens taken from the prompt context
-
-
-tokenizer = AutoTokenizer.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True)
-tokenizer.truncation_side = "left"  # if the prompt is truncated, keep its last N tokens (reading left to right)
-
-if device == "cuda":
-    model = AutoModelForCausalLM.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True, low_cpu_mem_usage=True).to(device, dtype=torch.bfloat16)
-else:
-    model = AutoModelForCausalLM.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True, low_cpu_mem_usage=True)
-
-model.eval()
-
-
-custom_css = """
-.gradio-container {
-    background-color: #0D1525;
-    color: white;
-}
-#orange-button {
-    background: #F26207 !important;
-    color: white;
-}
-.cm-gutters {
-    border: none !important;
-}
-"""
-
-
-def post_processing(prompt, completion):
-    return prompt + completion
-    # completion = "<span style='color: #499cd5;'>" + completion + "</span>"
-    # prompt = "<span style='color: black;'>" + prompt + "</span>"
-    # code_html = f"<hr><br><pre style='font-size: 14px'><code>{prompt}{completion}</code></pre><br><hr>"
-    # return code_html
-
-
-def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42, top_p=0.9, top_k=None, use_cache=True, repetition_penalty=1.0):
-    # truncate the prompt to MAX_INPUT_TOKENS if it is too long
-    x = tokenizer.encode(prompt, return_tensors="pt", max_length=MAX_INPUT_TOKENS, truncation=True).to(device)
-    print("Prompt shape: ", x.shape)  # logged so it is visible in the Space logs in production
-    set_seed(seed)
-    y = model.generate(
-        x,
-        max_new_tokens=max_new_tokens,
-        temperature=temperature,
-        pad_token_id=tokenizer.pad_token_id,
-        eos_token_id=tokenizer.eos_token_id,
-        top_p=top_p,
-        top_k=top_k,
-        use_cache=use_cache,
-        repetition_penalty=repetition_penalty,
-    )
-    completion = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
-    completion = completion[len(prompt):]
-    return post_processing(prompt, completion)
-
-
-demo = gr.Blocks(
-    css=custom_css
 )
 
-with demo:
-    gr.Markdown(value=description)
-    with gr.Row():
-        input_col, settings_col = gr.Column(scale=6), gr.Column(scale=6)
-        with input_col:
-            code = gr.Code(lines=28, label="Input", value="def sieve_eratosthenes(n):")
-        with settings_col:
-            with gr.Accordion("Generation Settings", open=True):
-                max_new_tokens = gr.Slider(
-                    minimum=8,
-                    maximum=128,
-                    step=1,
-                    value=48,
-                    label="Max Tokens",
-                )
-                temperature = gr.Slider(
-                    minimum=0.1,
-                    maximum=2.5,
-                    step=0.1,
-                    value=0.2,
-                    label="Temperature",
-                )
-                repetition_penalty = gr.Slider(
-                    minimum=1.0,
-                    maximum=1.9,
-                    step=0.1,
-                    value=1.0,
-                    label="Repetition Penalty. 1.0 means no penalty.",
-                )
-                seed = gr.Slider(
-                    minimum=0,
-                    maximum=1000,
-                    step=1,
-                    label="Random Seed",
-                )
-                top_p = gr.Slider(
-                    minimum=0.1,
-                    maximum=1.0,
-                    step=0.1,
-                    value=0.9,
-                    label="Top P",
-                )
-                top_k = gr.Slider(
-                    minimum=1,
-                    maximum=64,
-                    step=1,
-                    value=4,
-                    label="Top K",
-                )
-                use_cache = gr.Checkbox(
-                    label="Use Cache",
-                    value=True,
-                )
-
-    with gr.Row():
-        run = gr.Button(elem_id="orange-button", value="Generate More Code")
-
-    # with gr.Row():
-    #     # _, middle_col_row_2, _ = gr.Column(scale=1), gr.Column(scale=6), gr.Column(scale=1)
-    #     # with middle_col_row_2:
-    #     output = gr.HTML(label="Generated Code")
-
-    event = run.click(code_generation, [code, max_new_tokens, temperature, seed, top_p, top_k, use_cache, repetition_penalty], code, api_name="predict")
-
-demo.queue(max_size=40).launch()
+import streamlit as st
+from transformers import pipeline
+
+# App Title
+st.set_page_config(page_title="ML Assistant with Replit LLM", layout="wide")
+st.title("🤖 ML Assistant with Replit LLM")
+st.write("Interact with the Replit LLM for machine learning workflows and AI-driven coding assistance.")
+
+# Sidebar Configuration
+st.sidebar.title("Configuration")
+api_key = st.sidebar.text_input("Replit LLM API Key", type="password")
+model_name = st.sidebar.text_input("Hugging Face Model Name", "Canstralian/RabbitRedux")
+task_type = st.sidebar.selectbox(
+    "Choose a Task",
+    ["Text Generation", "Pseudocode to Python", "ML Debugging", "Code Optimization"]
 )
 
+# Ensure API Key is Provided
+if not api_key:
+    st.warning("Please provide your Replit LLM API Key in the sidebar to continue.")
+    st.stop()
+
+# Initialize Replit LLM Pipeline
+try:
+    nlp_pipeline = pipeline("text2text-generation", model=model_name)
+    st.success("Model loaded successfully!")
+except Exception as e:
+    st.error(f"Error loading model: {e}")
+    st.stop()
+
+# Input Section
+st.subheader("Input Your Query")
+user_input = st.text_area("Enter your query or task description", height=150)
+
+# Submit Button
+if st.button("Generate Output"):
+    if user_input.strip() == "":
+        st.warning("Please enter a valid input.")
+    else:
+        with st.spinner("Processing..."):
+            try:
+                # Generate response using Replit LLM
+                output = nlp_pipeline(user_input)
+                response = output[0]["generated_text"]
+                st.subheader("AI Response")
+                st.write(response)
+            except Exception as e:
+                st.error(f"An error occurred: {e}")
+
+# Additional ML Features
+st.subheader("Advanced Machine Learning Assistance")
+
+if task_type == "Text Generation":
+    st.info("Use the input box to generate text-based output.")
+elif task_type == "Pseudocode to Python":
+    st.info("Provide pseudocode, and the Replit LLM will attempt to generate Python code.")
+    example = st.button("Show Example")
+    if example:
+        st.code("""
+# Pseudocode
+FOR each item IN list:
+    IF item > threshold:
+        PRINT "Above Threshold"
+
+# Expected Python Output
+for item in my_list:
+    if item > threshold:
+        print("Above Threshold")
+""")
+elif task_type == "ML Debugging":
+    st.info("Describe your ML pipeline error for debugging suggestions.")
+elif task_type == "Code Optimization":
+    st.info("Paste your Python code for optimization recommendations.")
+    user_code = st.text_area("Paste your Python code", height=200)
+    if st.button("Optimize Code"):
+        with st.spinner("Analyzing and optimizing..."):
+            try:
+                optimization_prompt = f"Optimize the following Python code:\n\n{user_code}"
+                output = nlp_pipeline(optimization_prompt)
+                optimized_code = output[0]["generated_text"]
+                st.subheader("Optimized Code")
+                st.code(optimized_code)
+            except Exception as e:
+                st.error(f"An error occurred: {e}")
+
+# Footer
+st.write("---")
+st.write("Powered by [Replit LLM](https://replit.com) and [Hugging Face](https://huggingface.co).")