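"""Gradio chat demo: Llama 3.2 3B Instruct extracts transaction IDs from support messages."""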
import os

from huggingface_hub import login

# Log in to the Hugging Face Hub; the token is read from the environment.
login(token=os.getenv("HUGGINGFACEHUB_API_KEY"))

from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import torch


def get_device_type() -> str:
    # Prefer CUDA, then Apple Silicon (MPS), and fall back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"


model_name = "meta-llama/Llama-3.2-3B-Instruct"

device = get_device_type()
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

system_prompt = """
You are a customer support officer. Extract the transaction ID from the USER INPUT and determine its type.

Extraction rules:
- Look for words starting with 'payout' or 'payment'.
- The next character must be a dash ('-').
- There must be at least one digit or character after the dash.
- The transaction ID must appear exactly in the USER INPUT.
- If found, set found to true; otherwise, set found to false.

Type rules:
- If the transaction ID starts with 'payout', type is payout.
- If it starts with 'payment', type is payment.

===>USER INPUT BEGINS
{input}
<===USER INPUT ENDS

Respond in valid JSON with these fields:
found: (boolean) Whether a valid transaction ID was found.
transaction_id: (string, if found) The extracted transaction ID.
transaction_type: (string, if found) The transaction type in lowercase.
justification: (string) Explain how you determined the transaction ID and type. If not found, do not fabricate an explanation.
Return only valid JSON and nothing else.
"""

examples = [
    "My transaction payment-a1c1 failed",
    "Why is my withdrawal payout-b2c2 pending for 3 days",
    "There is an issue with my transaction payout-87l2k3",
    "I am having trouble with my transaction",
]


def predict(message, history):
    # Embed the current message into the system prompt so the extraction
    # rules always refer to the latest USER INPUT.
    sys_prompt = system_prompt.replace("{input}", message)
    if not history or history[0].get("role") != "system":
        history = [{"role": "system", "content": sys_prompt}] + history
    else:
        history[0]["content"] = sys_prompt

    history.append({"role": "user", "content": message})

    # Render the conversation with the model's chat template and request a
    # fresh assistant turn. add_special_tokens=False avoids a duplicate BOS:
    # the Llama 3 template already inserts <|begin_of_text|>.
    prompt = tokenizer.apply_chat_template(
        history, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    # pad_token_id silences the missing-pad-token warning for Llama tokenizers.
    outputs = model.generate(
        **inputs, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id
    )

    # Decode only the newly generated tokens. The previous approach of
    # splitting on ChatML markers ("<|im_start|>assistant") never matched:
    # Llama 3 uses different special tokens, and skip_special_tokens=True
    # strips them anyway, so the whole prompt leaked back into the reply.
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    )

    return response.strip()
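
# Minimal sketch (an addition, not part of the original flow): the system
# prompt demands JSON-only replies, so downstream code could validate the
# reply like this before trusting its fields.
import json

def parse_response(response):
    """Return the parsed JSON object, or None if the reply is not valid JSON."""
    try:
        return json.loads(response)
    except json.JSONDecodeError:
        return None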


demo = gr.ChatInterface(predict, type="messages", examples=examples)

# share=True also serves the demo through a temporary public Gradio link.
demo.launch(share=True)