# train.py (2.2 KB)
  1. import sys
  2. import pandas as pd
  3. from sklearn.model_selection import train_test_split
  4. from sklearn.feature_extraction.text import TfidfVectorizer
  5. from sklearn.preprocessing import StandardScaler
  6. from sklearn.compose import ColumnTransformer
  7. from sklearn.pipeline import Pipeline
  8. from sklearn.linear_model import LogisticRegression
  9. from sklearn.metrics import classification_report
  10. from sklearn.base import BaseEstimator, TransformerMixin
  11. import joblib
  12. # Custom transformer to reshape data
  13. class ReshapeTransformer(BaseEstimator, TransformerMixin):
  14. def fit(self, X, y=None):
  15. return self
  16. def transform(self, X):
  17. return X.values.reshape(-1, 1)
  18. # Load the dataset
  19. df = pd.read_csv('Learn.csv') # assuming the file name is transactions.csv
  20. # Display the first few rows of the dataset
  21. print(df.head())
  22. # Split data into features and target
  23. X = df[['merchant_name', 'amount']]
  24. y = df['category']
  25. # Split the data into training and testing sets
  26. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
  27. # Define the preprocessing for both text and numerical data
  28. preprocessor = ColumnTransformer(
  29. transformers=[
  30. ('merchant_tfidf', TfidfVectorizer(), 'merchant_name'),
  31. ('amount_scaler', Pipeline([
  32. ('reshape', ReshapeTransformer()),
  33. ('scaler', StandardScaler())
  34. ]), 'amount')
  35. ]
  36. )
  37. # Create the full pipeline with preprocessing and model
  38. pipeline = Pipeline([
  39. ('preprocessor', preprocessor),
  40. ('classifier', LogisticRegression(max_iter=1000))
  41. ])
  42. # Train the model
  43. pipeline.fit(X_train, y_train)
  44. # Predict on the test set
  45. y_pred = pipeline.predict(X_test)
  46. # Save the model to a file
  47. joblib.dump(pipeline, 'model_pipeline.pkl')
  48. print("Model saved successfully.")
  49. # Load the model from the file
  50. loaded_pipeline = joblib.load('model_pipeline.pkl')
  51. print("Model loaded successfully.")
  52. if(len(sys.argv) > 1):
  53. pred = loaded_pipeline.predict(pd.DataFrame({
  54. 'merchant_name' : [sys.argv[1]],
  55. 'amount' : ['5300']
  56. }))
  57. print (f"Merchant {sys.argv[1]} -> {pred}")
  58. # Print the classification report
  59. print(classification_report(y_test, y_pred, zero_division=0))