# Hugging Face file-view header (author: Amelia-James):
# commit f635cd7 "Create app.py" (verified) — app.py, 3.27 kB
import streamlit as st
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, RagTokenizer, RagRetriever, RagSequenceForGeneration
from pymilvus import connections, Collection, CollectionSchema, FieldSchema, DataType
from dotenv import load_dotenv
import os
# Load environment variables (from a local .env file, if present).
load_dotenv()
# NOTE(review): GROQ_API_KEY is read but not used anywhere else in this file.
GROQ_API_KEY = os.getenv('GROQ_API_KEY')


# Load Hugging Face models
@st.cache_resource
def _load_models():
    """Load the embedding and RAG models once per Streamlit server process.

    Streamlit re-executes this whole script on every widget interaction;
    caching the loaded models avoids re-reading the weights on each rerun.

    Returns:
        ``(embedding_tokenizer, embedding_model, rag_tokenizer, rag_retriever,
        rag_model)``.
    """
    emb_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    emb_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    # index_name="custom" requires passages_path/index_path and fails without
    # them; the documented "exact" + dummy-dataset setup loads without them.
    # NOTE(review): this retriever is not attached to the RAG model and is not
    # used elsewhere in this file.
    rag_retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
    )
    rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
    return emb_tokenizer, emb_model, rag_tokenizer, rag_retriever, rag_model


tokenizer, model, tokenizer_rag, retriever, model_rag = _load_models()

# Initialize Milvus connection
connections.connect("default", host="localhost", port="19530")

# Define Milvus schema and collection. The vector dimension must match the
# embedding model's hidden size (384 for all-MiniLM-L6-v2, not 768).
# NOTE(review): a "user_data" collection previously created with dim=768 must
# be dropped before this schema can be used.
EMBEDDING_DIM = model.config.hidden_size
fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM),
]
schema = CollectionSchema(fields, "User Data Collection")
collection = Collection(name="user_data", schema=schema)
# Milvus can only search a collection that is indexed and loaded into memory.
# IVF_FLAT with L2 matches the L2/nprobe search parameters used for retrieval.
if not collection.has_index():
    collection.create_index(
        field_name="embedding",
        index_params={"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 128}},
    )
collection.load()
# Define functions
def generate_embedding(text):
    """Embed *text* with the MiniLM sentence encoder.

    Token vectors from the last hidden layer are mean-pooled, weighted by the
    attention mask so padding positions (possible because the tokenizer is
    called with ``padding=True``) do not dilute the average.

    Args:
        text: The string to embed. If a list of strings is passed, only the
            first row's embedding is returned.

    Returns:
        The embedding as a plain Python list of floats.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    token_vectors = outputs.last_hidden_state  # (batch, seq, hidden)
    # (batch, seq, 1): 1.0 for real tokens, 0.0 for padding.
    mask = inputs["attention_mask"].unsqueeze(-1).to(token_vectors.dtype)
    summed = (token_vectors * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid division by zero
    return (summed / counts)[0].tolist()
def insert_data(user_id, embedding):
    """Insert one entity (``id``, ``embedding``) into the Milvus collection.

    pymilvus' column-based ``insert`` expects one list per schema field, in
    field order, so a single row is passed as ``[[user_id], [embedding]]``.

    Args:
        user_id: Integer primary key for the row.
        embedding: The vector as a list of floats.

    Returns:
        The pymilvus mutation result (the original returned ``None``).
    """
    return collection.insert([[user_id], [embedding]])
def retrieve_relevant_data(query, top_k=5):
    """Search the Milvus collection for vectors nearest to *query*.

    Args:
        query: Free text; embedded with ``generate_embedding``.
        top_k: Maximum number of hits to return (Milvus ``limit``).

    Returns:
        The pymilvus search result: one ``Hits`` sequence per query vector
        (here exactly one).
    """
    query_embedding = generate_embedding(query)
    search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
    # ``search`` takes a batch of query vectors and a required ``limit``.
    # NOTE(review): the collection must be indexed and loaded before searching.
    return collection.search(
        data=[query_embedding],
        anns_field="embedding",
        param=search_params,
        limit=top_k,
    )
def _hit_texts(results):
    """Collect the ``text`` payload of every hit in a Milvus search result.

    Hits without a ``text`` field are skipped. The collection schema in this
    file stores only ``id`` and ``embedding``, so this yields nothing until a
    text field is added and requested via ``output_fields``.
    """
    from pymilvus import MilvusException  # pymilvus is already a file dependency

    texts = []
    for hits in results:  # one Hits sequence per query vector
        for hit in hits:
            try:
                text = hit.entity.get("text")
            except (AttributeError, KeyError, MilvusException):
                text = None  # field not present in this hit's payload
            if text:
                texts.append(text)
    return texts


def generate_cv(job_description, company_profile=None, user_profile=None):
    """Generate CV text for *job_description* with the RAG-sequence model.

    Args:
        job_description: Free-text job posting; used as the question.
        company_profile: Optional text appended to the question.
        user_profile: Optional candidate summary used as grounding context,
            placed before any passages returned by the Milvus search.

    Returns:
        The decoded text of the first generated sequence.
    """
    query = job_description
    if company_profile:
        query += f" Company profile: {company_profile}"

    passages = _hit_texts(retrieve_relevant_data(query))
    if user_profile:
        passages.insert(0, user_profile)
    context = " ".join(passages)

    # RAG's generator reads "<passage> // <question>" encoded with the
    # generator (BART) tokenizer; calling RagTokenizer directly would use the
    # question-encoder tokenizer instead.
    n_docs = model_rag.config.n_docs
    encoded = tokenizer_rag.generator(
        [f"{context} // {query}"] * n_docs,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )
    # model_rag was loaded without a retriever, so generate() must be given the
    # context ids, their attention mask and per-document scores explicitly.
    # The single context is repeated n_docs times with equal scores.
    outputs = model_rag.generate(
        context_input_ids=encoded["input_ids"],
        context_attention_mask=encoded["attention_mask"],
        doc_scores=torch.ones((1, n_docs)),
        n_docs=n_docs,
    )
    return tokenizer_rag.decode(outputs[0], skip_special_tokens=True)


# Streamlit UI
st.title("Custom CV Generator")
st.sidebar.header("Input Data")
skills = st.sidebar.text_input("Enter your skills")
experience = st.sidebar.text_input("Enter your experience")
education = st.sidebar.text_input("Enter your education")
job_description = st.sidebar.text_area("Enter job description")
company_profile = st.sidebar.text_area("Enter company profile (optional)")
if st.sidebar.button("Generate CV"):
    if not job_description.strip():
        st.warning("Please enter a job description.")
    else:
        # Insert user data (assuming single user for simplicity)
        user_data = f"Skills: {skills}. Experience: {experience}. Education: {education}."
        # NOTE(review): every click inserts another row with id 1; consider
        # deleting/upserting the previous row.
        user_id = 1  # Example user ID
        user_embedding = generate_embedding(user_data)
        insert_data(user_id, user_embedding)
        # Generate CV, grounding the model on the user's own data.
        cv_text = generate_cv(job_description, company_profile, user_profile=user_data)
        st.write("Your Tailored CV:")
        st.write(cv_text)