Spaces:
Sleeping
Sleeping
import logging | |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
import gradio as gr | |
# Set up logging: configure the root logger to append INFO-and-above records
# to app.log, each formatted as "<timestamp> - <LEVEL> - <message>".
logging.basicConfig(
    level=logging.INFO,
    filename="app.log",
    format="%(asctime)s - %(levelname)s - %(message)s",
)
# Hugging Face Hub id of the model to use. Any id that AutoModelForCausalLM
# can load may be substituted here.
MODEL_NAME = "Qwen/Qwen2.5-Coder-3B-Instruct"
# Load the tokenizer and model once, at import time. If either fails, the
# error is logged and re-raised so the app does not start half-initialised.
logging.info("Loading the model: %s...", MODEL_NAME)
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    logging.info("Model loaded successfully.")
except Exception as e:
    # logging.exception records the full traceback, not just the message,
    # so app.log shows why loading failed.
    logging.exception("Error loading the model: %s", e)
    raise
# Define a function to generate test cases
def generate_test_cases(api_info, *, max_new_tokens=512):
    """Generate API test cases for *api_info* with the loaded language model.

    The prompt asks for happy-path, negative and edge-case tests. When the
    tokenizer defines a chat template, the prompt is sent through it as a
    single user message, which is the input format instruct models expect.

    Args:
        api_info: Text describing the API under test (URL, method, headers,
            payload); it is embedded verbatim in the prompt.
        max_new_tokens: Maximum number of tokens to generate after the
            prompt, so long API descriptions do not consume the answer's
            token budget. Defaults to 512.

    Returns:
        Only the newly generated text (the prompt is not echoed back), or
        "An error occurred while generating test cases." if anything fails;
        in that case the traceback is written to the log.
    """
    try:
        # api_info is deliberately not logged verbatim: it contains
        # user-supplied headers, which may carry credentials.
        logging.info(
            "Generating test cases for %d characters of API info.", len(api_info)
        )
        prompt = (
            f"Generate API test cases for the following API:\n\n{api_info}\n\n"
            "Test cases should include:\n- Happy path\n- Negative tests\n- Edge cases"
        )
        # Wrap the request in the model's chat format when one is defined.
        if getattr(tokenizer, "chat_template", None):
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=False,
                add_generation_prompt=True,
            )
        # Tokenize the input prompt (a single sequence, so no padding needed).
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            # max_new_tokens bounds only the continuation; max_length would
            # also count the prompt tokens.
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,
        )
        # generate() returns prompt + continuation for this causal model;
        # drop the prompt tokens so only the model's answer is returned.
        prompt_length = inputs["input_ids"].shape[-1]
        generated_text = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        logging.info("Test cases generated successfully.")
        return generated_text
    except Exception as e:
        logging.exception("Error generating test cases: %s", e)
        return "An error occurred while generating test cases."
# Process input and generate output
def process_input(url, method, headers, payload):
    """Gradio callback: combine the form fields and generate test cases.

    Args:
        url: API endpoint URL as typed by the user.
        method: HTTP method as typed by the user (free text, not validated).
        headers: Request headers as typed by the user; labelled "JSON" in
            the UI but passed through verbatim without parsing.
        payload: Request body as typed by the user; also passed verbatim.

    Returns:
        The text returned by generate_test_cases, or a fixed error message
        if anything unexpected fails (the traceback is logged).
    """
    try:
        logging.info("Received user input.")
        api_info = f"URL: {url}\nMethod: {method}\nHeaders: {headers}\nPayload: {payload}"
        # Log only the size: headers commonly carry credentials.
        logging.debug("Formatted API info (%d characters).", len(api_info))
        return generate_test_cases(api_info)
    except Exception as e:
        # UI boundary: keep the traceback in the log, show a friendly message.
        logging.exception("Error processing input: %s", e)
        return "An error occurred. Please check the input format and try again."
# Define Gradio interface: four free-text fields in, plain text out.
interface = gr.Interface(
    fn=process_input,
    # One textbox per process_input parameter, in the same positional order
    # (url, method, headers, payload).
    inputs=[
        gr.Textbox(label=field_label)
        for field_label in (
            "API URL",
            "HTTP Method",
            "Headers (JSON format)",
            "Payload (JSON format)",
        )
    ],
    outputs="text",
    title="API Test Case Generator ({})".format(MODEL_NAME),
)
# Launch Gradio app
if __name__ == "__main__":
    logging.info("Starting the Gradio app...")
    try:
        # By default launch() blocks the main thread while serving, so the
        # code after it runs only once the server has shut down.
        interface.launch()
    except Exception as e:
        # Keep the traceback and propagate the failure instead of exiting
        # as if the app had run normally.
        logging.exception("Error launching the Gradio app: %s", e)
        raise
    logging.info("Gradio app stopped.")