File size: 2,154 Bytes
ce3c914
 
 
2c9b240
 
ce3c914
2c9b240
ce3c914
 
2d3701f
ce3c914
881f44e
 
9978119
2c9b240
9978119
881f44e
ce3c914
 
2c9b240
 
 
 
 
 
ce3c914
2c9b240
ce3c914
2c9b240
 
 
 
 
 
 
 
 
 
 
 
ce3c914
 
 
 
 
2c9b240
ce3c914
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configure the browser-tab title and the full-width page layout.
# Keep this ahead of other rendering st.* calls: older Streamlit releases
# require set_page_config to be the first Streamlit command in the script.
st.set_page_config(
    layout="wide",
    page_title="SQLCoder: AI-Powered SQL Code Generator",
)

@st.cache_resource
def load_model():
    """Load and return the ``defog/sqlcoder`` tokenizer and model.

    Decorated with ``st.cache_resource`` so Streamlit keeps the returned
    objects across script reruns instead of reloading them each time.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "defog/sqlcoder"
    model_kwargs = dict(
        device_map="auto",  # let transformers/accelerate place the weights on available devices
        offload_folder="./offload_weights",  # directory for weights offloaded to disk
        low_cpu_mem_usage=True,
    )

    sqlcoder_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    sqlcoder_model = AutoModelForCausalLM.from_pretrained(checkpoint, **model_kwargs)
    return sqlcoder_tokenizer, sqlcoder_model

# Load the model and tokenizer once at startup. Without them the app cannot
# do anything useful, so surface the error in the page and halt this run.
try:
    tokenizer, model = load_model()
except Exception as load_error:
    st.error(f"Error loading model: {load_error}")
    st.stop()

def generate_code(prompt, max_length=150, temperature=0.7, top_k=50):
    """Generate a completion for *prompt* with the module-level SQLCoder model.

    Args:
        prompt: Natural-language question or partial code to complete.
        max_length: Maximum number of *new* tokens to generate. The prompt's
            own tokens do not count against this budget, so long prompts
            (for example ones that embed a table schema) still get output.
        temperature: Sampling temperature; higher values give more varied text.
        top_k: Sample only from the ``top_k`` most likely next tokens.

    Returns:
        The decoded continuation (prompt not repeated), special tokens skipped.

    Raises:
        RuntimeError: If tokenization or generation fails; the original
            exception is chained as ``__cause__``.
    """
    try:
        # The model was loaded with device_map="auto" and may sit on a GPU,
        # while tokenizer output starts on the CPU; align the devices.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            # max_length would include the prompt tokens; cap only new tokens.
            max_new_tokens=max_length,
            num_return_sequences=1,
            # temperature/top_k only take effect when sampling is enabled;
            # without do_sample=True generate() decodes greedily and ignores them.
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
        )
        # Decoder-only models return prompt + continuation; keep the new part.
        prompt_length = inputs["input_ids"].shape[1]
        return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    except Exception as e:
        raise RuntimeError(f"Error during code generation: {e}") from e

# Streamlit app layout
st.title("SQLCoder: AI-Powered SQL Code Generator")
st.write("Generate SQL queries and code snippets using the SQLCoder model.")

# Input prompt and generation budget (passed to generate_code as max_length)
prompt = st.text_area("Enter your query or code prompt:", height=150)
max_length = st.slider("Maximum Output Length", min_value=50, max_value=300, value=150, step=10)

# st.button is True only on the rerun triggered by the click, so the result
# below is rendered once per click.
if st.button("Generate Code"):
    if prompt.strip():
        with st.spinner("Generating code..."):
            try:
                generated_code = generate_code(prompt, max_length=max_length)
            except Exception as e:
                st.error(f"Error generating code: {e}")
            else:
                # Render read-only with SQL highlighting. An editable text_area
                # here would trigger a rerun when edited; the button is no longer
                # pressed on that rerun, so the result would disappear.
                st.write("Generated Code:")
                st.code(generated_code, language="sql")
    else:
        st.warning("Please enter a prompt before generating code.")